Skip to content

Commit

Permalink
fix: add promotion code support to subscription creation (#673)
Browse files Browse the repository at this point in the history
  • Loading branch information
jhonylucas74 authored Sep 11, 2024
1 parent f3d02e7 commit 0ac4dd4
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 13 deletions.
81 changes: 81 additions & 0 deletions backend/apps/account_payment/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,24 +207,39 @@ def mutate(cls, root, info, price_id, coupon=None):

price = DJStripePrice.objects.get(djstripe_id=price_id)
is_trial_active = len(internal_subscriptions) == 0
promotion_code = None

try:
promotion = get_stripe_promo(coupon)
if promotion and promotion.active:
promotion_code = promotion.id
except Exception:
...

customer, _ = DJStripeCustomer.get_or_create(admin)
price_id = price.id

if promotion_code:
discounts = [{"promotion_code": promotion_code}]
else:
discounts = []

if is_trial_active:
subscription = None
setup_intent = SetupIntent.create(
customer=customer.id,
usage="off_session",
metadata={
"price_id": price_id,
"promotion_code": promotion_code,
},
)
else:
subscription: DJStripeSubscription = customer.subscribe(
price=price_id,
payment_behavior="default_incomplete",
payment_settings={"save_default_payment_method": "on_subscription"},
discounts=discounts,
)

if subscription:
Expand All @@ -237,6 +252,49 @@ def mutate(cls, root, info, price_id, coupon=None):
return cls(errors=[str(e)])


class StripeCouponValidationMutation(Mutation):
"""Validate a Stripe coupon and return discount details"""

is_valid = Boolean()
discount_amount = Float()
errors = List(String)

class Arguments:
coupon = String(required=True)
price_id = ID(required=True)

@classmethod
@login_required
def mutate(cls, root, info, coupon, price_id):
try:
try:
promotion_code_object = get_stripe_promo(coupon)
if promotion_code_object and promotion_code_object.active:
coupon_object = stripe.Coupon.retrieve(promotion_code_object.coupon.id)
else:
return cls(is_valid=False, discount_amount=0, errors=["Cupom inválido"])
except Exception as e:
return cls(is_valid=False, discount_amount=0, errors=["Cupom inválido", str(e)])

if not coupon_object.valid:
return cls(is_valid=False, discount_amount=0, errors=["Cupom não está ativo"])

price = DJStripePrice.objects.get(djstripe_id=price_id)
price_amount = price.unit_amount / 100.0

discount_amount = 0.0

if coupon_object.amount_off:
discount_amount = coupon_object.amount_off / 100.0
elif coupon_object.percent_off:
discount_amount = (coupon_object.percent_off / 100.0) * price_amount

return cls(is_valid=True, discount_amount=discount_amount)
except Exception as e:
logger.error(e)
return cls(is_valid=False, errors=[str(e)])


class StripeSubscriptionDeleteMutation(Mutation):
"""Delete stripe subscription"""

Expand Down Expand Up @@ -319,6 +377,28 @@ def mutate(cls, root, info, account_id, subscription_id):
return cls(errors=[str(e)])


def get_stripe_promo(promotion_code):
"""
Helper function to retrieve a Stripe Promotion Code by its code.
:param promotion_code: The code of the promotion to be retrieved.
:return: The Stripe Promotion Code object if found.
:raises Exception: If the promotion code is not found or any error occurs.
"""
if not promotion_code:
raise Exception("Promotion code not provided")
try:
promotion_code_list = stripe.PromotionCode.list(code=promotion_code, limit=1)

if promotion_code_list.data:
return promotion_code_list.data[0]
else:
raise Exception("Promotion code not found")

except Exception as e:
raise Exception(f"Error retrieving promotion code: {str(e)}")


class Query(ObjectType):
stripe_price = PlainTextNode.Field(StripePriceNode)
all_stripe_price = DjangoFilterConnectionField(StripePriceNode)
Expand All @@ -331,6 +411,7 @@ class Mutation(ObjectType):
delete_stripe_subscription = StripeSubscriptionDeleteMutation.Field()
create_stripe_customer_subscription = StripeSubscriptionCustomerCreateMutation.Field()
update_stripe_customer_subscription = StripeSubscriptionCustomerDeleteMutation.Field()
validate_stripe_coupon = StripeCouponValidationMutation.Field()


# Reference
Expand Down
11 changes: 7 additions & 4 deletions backend/apps/account_payment/webhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def setup_intent_succeeded(event: Event, **kwargs):
setup_intent = event.data["object"]
metadata = setup_intent.get("metadata")
price_id = metadata.get("price_id")
promotion_code = metadata.get("promotion_code")

StripeCustomer.modify(
customer.id, invoice_settings={"default_payment_method": setup_intent.get("payment_method")}
Expand All @@ -199,12 +200,14 @@ def setup_intent_succeeded(event: Event, **kwargs):
subscriptions = StripeSubscription.list(customer=customer.id)
has_subscription = len(subscriptions.get("data")) > 0

if promotion_code:
discounts = [{"promotion_code": promotion_code}]
else:
discounts = []

if not has_subscription and price_id:
logger.info(f"Add subscription to user {event.customer.email}")
customer.subscribe(
price=price_id,
trial_period_days=7,
)
customer.subscribe(price=price_id, trial_period_days=7, discounts=discounts)


# Reference
Expand Down
11 changes: 2 additions & 9 deletions backend/custom/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
LOGGER_LEVEL = getenv("LOGGER_LEVEL", "DEBUG")
LOGGER_IGNORE = getenv("LOGGER_IGNORE", "").split(",")
LOGGER_SERIALIZE = bool(getenv("LOGGER_SERIALIZE", False))
LOGGER_FORMAT = "[{time:YYYY-MM-DD HH:mm:ss}] <lvl>{extra[app_name]}: {message}</>"
LOGGER_FORMAT = "[{time:YYYY-MM-DD HH:mm:ss}] <lvl>{message}</>"


class InterceptHandler(Handler):
Expand All @@ -28,14 +28,7 @@ def emit(self, record: LogRecord):
frame = frame.f_back
depth += 1

# Include the logger name (app name) in the log record
app_name = record.name
extra = record.__dict__.get("extra", {})
extra["app_name"] = app_name

logger.bind(**extra).opt(depth=depth, exception=record.exc_info).log(
level, record.getMessage()
)
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())


def setup_logger(logging_settings=None):
Expand Down

0 comments on commit 0ac4dd4

Please sign in to comment.