diff --git a/pay-api/src/pay_api/services/refund.py b/pay-api/src/pay_api/services/refund.py index 0dd6fa5ef..fa003f202 100644 --- a/pay-api/src/pay_api/services/refund.py +++ b/pay-api/src/pay_api/services/refund.py @@ -301,6 +301,14 @@ def _validate_partial_refund_lines(refund_revenue: List[RefundPartialLine], invo raise BusinessException(Error.REFUND_PAYMENT_LINE_ITEM_INVALID) RefundService._validate_refund_amount(refund_line, payment_line) + @classmethod + def _validate_allow_partial_refund(cls, refund_revenue, invoice: InvoiceModel): + if refund_revenue: + if not flags.is_on("enable-partial-refunds", default=False): + raise BusinessException(Error.INVALID_REQUEST) + if invoice.corp_type.has_partner_disbursements: + raise BusinessException(Error.PARTIAL_REFUND_DISBURSEMENTS_UNSUPPORTED) + @classmethod @user_context def create_refund(cls, invoice_id: int, request: Dict[str, str], **kwargs) -> Dict[str, str]: @@ -331,8 +339,7 @@ def create_refund(cls, invoice_id: int, request: Dict[str, str], **kwargs) -> Di ) payment_account = PaymentAccount.find_by_id(invoice.payment_account_id) refund_revenue = (request or {}).get("refundRevenue", None) - if refund_revenue and not flags.is_on("enable-partial-refunds", default=False): - raise BusinessException(Error.INVALID_REQUEST) + cls._validate_allow_partial_refund(refund_revenue, invoice) refund_partial_lines = cls._get_partial_refund_lines(refund_revenue) cls._validate_partial_refund_lines(refund_partial_lines, invoice) diff --git a/pay-api/src/pay_api/utils/errors.py b/pay-api/src/pay_api/utils/errors.py index 87687ec98..c0f53d5ee 100644 --- a/pay-api/src/pay_api/utils/errors.py +++ b/pay-api/src/pay_api/utils/errors.py @@ -151,6 +151,7 @@ class Error(Enum): ) REFUND_PAYMENT_LINE_ITEM_INVALID = "REFUND_PAYMENT_LINE_ITEM_INVALID", HTTPStatus.BAD_REQUEST REFUND_AMOUNT_INVALID = "REFUND_AMOUNT_INVALID", HTTPStatus.BAD_REQUEST + PARTIAL_REFUND_DISBURSEMENTS_UNSUPPORTED = "PARTIAL_REFUND_DISBURSEMENTS_UNSUPPORTED", HTTPStatus.BAD_REQUEST ROUTING_SLIP_REFUND = "ROUTING_SLIP_REFUND", HTTPStatus.BAD_REQUEST NO_FEE_REFUND = "NO_FEE_REFUND", HTTPStatus.BAD_REQUEST diff --git a/pay-api/src/pay_api/version.py b/pay-api/src/pay_api/version.py index 751976517..9c94c2b7d 100644 --- a/pay-api/src/pay_api/version.py +++ b/pay-api/src/pay_api/version.py @@ -22,4 +22,4 @@ Development release segment: .devN """ -__version__ = "1.22.7" # pylint: disable=invalid-name +__version__ = "1.22.8" # pylint: disable=invalid-name diff --git a/pay-api/tests/unit/api/test_partial_refund.py b/pay-api/tests/unit/api/test_partial_refund.py index 8c8c77e33..34b30d710 100644 --- a/pay-api/tests/unit/api/test_partial_refund.py +++ b/pay-api/tests/unit/api/test_partial_refund.py @@ -187,6 +187,45 @@ def test_create_refund_fails(session, client, jwt, app, monkeypatch): assert len(refunds_partial) == 0 +def test_refund_validation_for_disbursements(session, client, jwt, app, monkeypatch): + """Assert that the partial refund amount validation returns 400 when the invoice corp_type has disbursements.""" + token = jwt.create_jwt(get_claims(app_request=app), token_header) + headers = {"Authorization": f"Bearer {token}", "content-type": "application/json"} + + rv = client.post( + "/api/v1/payment-requests", + data=json.dumps(get_payment_request()), + headers=headers, + ) + inv_id = rv.json.get("id") + invoice: InvoiceModel = InvoiceModel.find_by_id(inv_id) + invoice.invoice_status_code = InvoiceStatus.PAID.value + invoice.corp_type_code = 'VS' + invoice.save() + + token = jwt.create_jwt(get_claims(app_request=app, role=Role.SYSTEM.value), token_header) + headers = {"Authorization": f"Bearer {token}", "content-type": "application/json"} + + payment_line_items: List[PaymentLineItemModel] = invoice.payment_line_items + refund_revenue = [ + { + "paymentLineItemId": payment_line_items[0].id, + "refundAmount": float(payment_line_items[0].filing_fees), + "refundType": RefundsPartialType.BASE_FEES.value, + } + ] + + with patch("pay_api.services.payment_service.flags.is_on", return_value=True): + rv = client.post( + f"/api/v1/payment-requests/{inv_id}/refunds", + data=json.dumps({"reason": "Test", "refundRevenue": refund_revenue}), + headers=headers, + ) + assert rv.status_code == 400 + assert rv.json.get("type") == Error.PARTIAL_REFUND_DISBURSEMENTS_UNSUPPORTED.name + assert RefundModel.find_by_invoice_id(inv_id) is None + + @pytest.mark.parametrize( "fee_type", [