Skip to content

Commit

Permalink
Adapts stripe checkout to function for random token number as well
Browse files Browse the repository at this point in the history
  • Loading branch information
emmdim committed Nov 27, 2024
1 parent 471dad8 commit 77defb2
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 20 deletions.
10 changes: 5 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func main() {
flag.Duration("waitPeriod", 1*time.Hour, "wait period between requests for the same user")
flag.StringP("dbType", "t", db.TypePebble, fmt.Sprintf("key-value db type [%s,%s,%s]", db.TypePebble, db.TypeLevelDB, db.TypeMongo))
flag.String("stripeKey", "", "stripe secret key")
flag.String("stripePriceId", "", "stripe price id")
flag.String("stripeProductID", "", "stripe price id")
flag.Int64("stripeMinQuantity", 100, "stripe min number of tokens")
flag.Int64("stripeMaxQuantity", 100000, "stripe max number of tokens")
flag.String("stripeWebhookSecret", "", "stripe webhook secret key")
Expand Down Expand Up @@ -98,7 +98,7 @@ func main() {
if err := viper.BindPFlag("stripeKey", flag.Lookup("stripeKey")); err != nil {
panic(err)
}
if err := viper.BindPFlag("stripePriceId", flag.Lookup("stripePriceId")); err != nil {
if err := viper.BindPFlag("stripeProductID", flag.Lookup("stripeProductID")); err != nil {
panic(err)
}
if err := viper.BindPFlag("stripeMinQuantity", flag.Lookup("stripeMinQuantity")); err != nil {
Expand Down Expand Up @@ -149,7 +149,7 @@ func main() {
waitPeriod := viper.GetDuration("waitPeriod")
dbType := viper.GetString("dbType")
stripeKey := viper.GetString("stripeKey")
stripePriceId := viper.GetString("stripePriceId")
stripeProductID := viper.GetString("stripeProductID")
stripeMinQuantity := viper.GetInt64("stripeMinQuantity")
stripeMaxQuantity := viper.GetInt64("stripeMaxQuantity")
stripeWebhookSecret := viper.GetString("stripeWebhookSecret")
Expand Down Expand Up @@ -215,7 +215,7 @@ func main() {
if amount := f.AuthTypes[faucet.AuthTypeStripe]; amount > 0 {
s, err = stripehandler.NewStripeClient(
stripeKey,
stripePriceId,
stripeProductID,
stripeWebhookSecret,
stripeMinQuantity,
stripeMaxQuantity,
Expand All @@ -226,7 +226,7 @@ func main() {
if err != nil {
log.Fatalf("stripe initialization error: %s", err)
} else {
log.Infof("stripe enabled with price id %s", stripePriceId)
log.Infof("stripe enabled with price id %s", stripeProductID)
}
}

Expand Down
64 changes: 49 additions & 15 deletions stripehandler/stripe.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
"sort"
"sync"

"github.com/stripe/stripe-go/v81"
Expand All @@ -18,7 +20,7 @@ import (
// StripeHandler represents the configuration for the stripe a provider for handling Stripe payments.
type StripeHandler struct {
Key string // The API key for the Stripe account.
PriceId string // The ID of the price associated with the product.
ProductID string // The ID of the price associated with the product.
MinQuantity int64 // The minimum quantity allowed for the product.
MaxQuantity int64 // The maximum quantity allowed for the product.
DefaultAmount int64 // The default amount for the product.
Expand All @@ -40,13 +42,13 @@ type ReturnStatus struct {
// NewStripeClient creates a new instance of the StripeHandler struct with the provided parameters.
// It sets the Stripe API key, price ID, webhook secret, minimum quantity, maximum quantity, and default amount.
// Returns a pointer to the created StripeHandler.
func NewStripeClient(key, priceId, webhookSecret string, minQuantity, maxQuantity, defaultAmount int64, faucet *faucet.Faucet, storage *storage.Storage) (*StripeHandler, error) {
if key == "" || priceId == "" || webhookSecret == "" || storage == nil {
func NewStripeClient(key, productID, webhookSecret string, minQuantity, maxQuantity, defaultAmount int64, faucet *faucet.Faucet, storage *storage.Storage) (*StripeHandler, error) {
if key == "" || productID == "" || webhookSecret == "" || storage == nil {
return nil, errors.New("missing required parameters")
}
stripe.Key = key
return &StripeHandler{
PriceId: priceId,
ProductID: productID,
MinQuantity: minQuantity,
MaxQuantity: maxQuantity,
DefaultAmount: defaultAmount,
Expand All @@ -64,25 +66,57 @@ func NewStripeClient(key, priceId, webhookSecret string, minQuantity, maxQuantit
// The function constructs a stripe.CheckoutSessionParams object with the provided parameters and creates a new session using the session.New function.
// If the session creation is successful, it returns the session pointer, otherwise it returns an error.
func (s *StripeHandler) CreateCheckoutSession(defaultAmount int64, to, returnURL, referral string) (*stripe.CheckoutSession, error) {
// search corresponding price tokens package
packName := fmt.Sprintf("pack_%d", defaultAmount)
priceParams := &stripe.PriceListParams{Active: stripe.Bool(true), LookupKeys: []*string{stripe.String(packName)}}
priceList := price.List(priceParams).PriceList()
// iterate price result
if len(priceList.Data) == 0 {
if defaultAmount <= 0 {
return nil, nil
}
params := &stripe.CheckoutSessionParams{
// get the different price packages
priceSearchParams := &stripe.PriceSearchParams{
SearchParams: stripe.SearchParams{
Query: fmt.Sprintf("product:'%s' AND active:'true'", s.ProductID),
},
}
priceSearchParams.Limit = stripe.Int64(100)
result := price.Search(priceSearchParams)
if result.Err() != nil {
return nil, result.Err()
}
var prices []stripe.Price
for result.Next() {
prices = append(prices, *result.Price())
}
var closestRoundedPrice *stripe.Price
// sorting prices in order to find the closest price to the default amount
sort.Slice(prices,
func(i, j int) bool {
return prices[i].TransformQuantity.DivideBy < prices[j].TransformQuantity.DivideBy
})
// find the closest price under the default amount
index := sort.Search(len(prices), func(i int) bool {
return prices[i].TransformQuantity.DivideBy > defaultAmount
})
if index == 0 {
closestRoundedPrice = &prices[0]
} else if index <= len(prices) {
closestRoundedPrice = &prices[index-1]
}
if closestRoundedPrice == nil {
return nil, nil
}
// calculate the price per token according to the package and
// round in order to fullfill the two decimals limits limitation of stripe
tempCalc := math.Round(float64(float64(closestRoundedPrice.UnitAmount) / float64(closestRoundedPrice.TransformQuantity.DivideBy) * float64(defaultAmount)))

checkoutParams := &stripe.CheckoutSessionParams{
ClientReferenceID: stripe.String(to),
UIMode: stripe.String("embedded"),
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
ReturnURL: stripe.String(returnURL + "/{CHECKOUT_SESSION_ID}"),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
PriceData: &stripe.CheckoutSessionLineItemPriceDataParams{
Product: &priceList.Data[0].Product.ID,
Currency: (*string)(&priceList.Data[0].Currency),
UnitAmountDecimal: &priceList.Data[0].UnitAmountDecimal,
Product: &s.ProductID,
Currency: stripe.String(string(closestRoundedPrice.Currency)),
UnitAmountDecimal: stripe.Float64(tempCalc),
},
Quantity: stripe.Int64(1),
},
Expand All @@ -92,7 +126,7 @@ func (s *StripeHandler) CreateCheckoutSession(defaultAmount int64, to, returnURL
"referral": referral,
},
}
ses, err := session.New(params)
ses, err := session.New(checkoutParams)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 77defb2

Please sign in to comment.