Skip to content

Commit

Permalink
More performance enhancements (#1850)
Browse files Browse the repository at this point in the history
  • Loading branch information
seeker25 authored Dec 9, 2024
1 parent fd39f80 commit 8c1dd70
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 64 deletions.
154 changes: 90 additions & 64 deletions pay-api/src/pay_api/models/payment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
import pytz
from flask import current_app
from marshmallow import fields
from sqlalchemy import Boolean, ForeignKey, String, and_, cast, func, or_
from sqlalchemy import Boolean, ForeignKey, String, and_, cast, func, or_, select
from sqlalchemy.dialects.postgresql import ARRAY, TEXT
from sqlalchemy.orm import relationship
from sqlalchemy.sql import select
from sqlalchemy.orm import contains_eager, lazyload, load_only, relationship

from pay_api.exceptions import BusinessException
from pay_api.utils.constants import DT_SHORT_FORMAT
Expand Down Expand Up @@ -252,50 +251,62 @@ def find_payments_to_consolidate(cls, auth_account_id: str):
def generate_base_transaction_query(cls):
"""Generate a base query."""
return (
db.session.query(
Invoice.id,
Invoice.payment_account_id,
Invoice.corp_type_code,
Invoice.created_on,
Invoice.payment_date,
Invoice.refund_date,
Invoice.invoice_status_code,
Invoice.total,
Invoice.service_fees,
Invoice.paid,
Invoice.refund,
Invoice.folio_number,
Invoice.created_name,
Invoice.invoice_status_code,
Invoice.payment_method_code,
Invoice.details,
Invoice.business_identifier,
Invoice.created_by,
Invoice.filing_id,
Invoice.bcol_account,
Invoice.disbursement_date,
Invoice.disbursement_reversal_date,
Invoice.overdue_date,
PaymentLineItem.id,
PaymentLineItem.description,
PaymentLineItem.gst,
PaymentLineItem.pst,
PaymentAccount.id,
PaymentAccount.auth_account_id,
PaymentAccount.name,
PaymentAccount.billable,
InvoiceReference.id,
InvoiceReference.invoice_number,
InvoiceReference.reference_number,
InvoiceReference.status_code,
)
.outerjoin(PaymentAccount, Invoice.payment_account_id == PaymentAccount.id)
.outerjoin(PaymentLineItem, PaymentLineItem.invoice_id == Invoice.id)
.outerjoin(
db.session.query(Invoice)
.join(PaymentAccount, Invoice.payment_account_id == PaymentAccount.id)
.join(PaymentLineItem, PaymentLineItem.invoice_id == Invoice.id)
.join(
FeeSchedule,
FeeSchedule.fee_schedule_id == PaymentLineItem.fee_schedule_id,
)
.outerjoin(InvoiceReference, InvoiceReference.invoice_id == Invoice.id)
.options(
lazyload("*"),
load_only(
Invoice.id,
Invoice.corp_type_code,
Invoice.created_on,
Invoice.payment_date,
Invoice.refund_date,
Invoice.invoice_status_code,
Invoice.total,
Invoice.service_fees,
Invoice.paid,
Invoice.refund,
Invoice.folio_number,
Invoice.created_name,
Invoice.invoice_status_code,
Invoice.payment_method_code,
Invoice.details,
Invoice.business_identifier,
Invoice.created_by,
Invoice.filing_id,
Invoice.bcol_account,
Invoice.disbursement_date,
Invoice.disbursement_reversal_date,
Invoice.overdue_date,
),
contains_eager(Invoice.payment_line_items)
.load_only(
PaymentLineItem.description,
PaymentLineItem.gst,
PaymentLineItem.pst,
PaymentLineItem.service_fees,
PaymentLineItem.total,
)
.contains_eager(PaymentLineItem.fee_schedule)
.load_only(FeeSchedule.filing_type_code),
contains_eager(Invoice.payment_account).load_only(
PaymentAccount.auth_account_id,
PaymentAccount.name,
PaymentAccount.billable,
PaymentAccount.branch_name,
),
contains_eager(Invoice.references).load_only(
InvoiceReference.invoice_number,
InvoiceReference.reference_number,
InvoiceReference.status_code,
),
)
)

@classmethod
Expand All @@ -320,7 +331,7 @@ def search_purchase_history( # noqa:E501; pylint:disable=too-many-arguments, to
count_future = executor.submit(cls.get_count, auth_account_id, search_filter)
sub_query = cls.generate_subquery(auth_account_id, search_filter, limit, page)
query = query.filter(Invoice.id.in_(sub_query.subquery().select())).order_by(Invoice.id.desc())
result_future = executor.submit(db.session.query(Invoice).from_statement(query).all)
result_future = executor.submit(query.all)
count = count_future.result()
result = result_future.result()
# If maximum number of records is provided, return it as total
Expand All @@ -330,16 +341,14 @@ def search_purchase_history( # noqa:E501; pylint:disable=too-many-arguments, to
# If maximum number of records is provided, set the page with that number
sub_query = cls.generate_subquery(auth_account_id, search_filter, max_no_records, page=None)
result, count = (
db.session.query(Invoice)
.from_statement(query.filter(Invoice.id.in_(sub_query.subquery().select())))
.all(),
query.filter(Invoice.id.in_(sub_query.subquery().select())).all(),
sub_query.count(),
)
else:
count = cls.get_count(auth_account_id, search_filter)
if count > 60000:
raise BusinessException(Error.PAYMENT_SEARCH_TOO_MANY_RECORDS)
result = db.session.query(Invoice).from_statement(query).all()
result = query.all()
return result, count

@classmethod
Expand Down Expand Up @@ -379,18 +388,15 @@ def get_invoices_and_payment_accounts_for_statements(cls, search_filter: Dict):
@classmethod
def get_count(cls, auth_account_id: str, search_filter: Dict):
"""Slimmed downed version for count (less joins)."""
query = cls.generate_base_transaction_query()
query = cls.filter(query, auth_account_id, search_filter)
count = query.group_by(Invoice.id).with_entities(func.count()).count() # pylint:disable=not-callable
query = db.session.query(func.distinct(Invoice.id))
query = cls.filter(query, auth_account_id, search_filter, include_joins=True)
count = query.count()
return count

@classmethod
def filter(cls, query, auth_account_id: str, search_filter: Dict):
def filter(cls, query, auth_account_id: str, search_filter: Dict, include_joins=False):
"""For filtering queries."""
if auth_account_id:
query = query.filter(PaymentAccount.auth_account_id == auth_account_id)
if account_name := search_filter.get("accountName", None):
query = query.filter(PaymentAccount.name.ilike(f"%{account_name}%"))
query = cls.filter_payment_account(query, auth_account_id, search_filter, include_joins)
if status_code := search_filter.get("statusCode", None):
query = query.filter(Invoice.invoice_status_code == status_code)
if search_filter.get("status", None):
Expand All @@ -408,14 +414,31 @@ def filter(cls, query, auth_account_id: str, search_filter: Dict):
if invoice_id := search_filter.get("id", None):
query = query.filter(cast(Invoice.id, String).like(f"%{invoice_id}%"))
if invoice_number := search_filter.get("invoiceNumber", None):
if include_joins:
query = query.join(InvoiceReference, InvoiceReference.invoice_id == Invoice.id)
query = query.filter(InvoiceReference.invoice_number.ilike(f"%{invoice_number}%"))

query = cls.filter_corp_type(query, search_filter)
query = cls.filter_payment(query, search_filter)
query = cls.filter_details(query, search_filter)
query = cls.filter_details(query, search_filter, include_joins)
query = cls.filter_date(query, search_filter)
return query

@classmethod
def filter_payment_account(cls, query, auth_account_id, search_filter: dict, include_joins=False):
"""Use subquery to look for payment accounts ahead of time, much faster and easier."""
account_name = search_filter.get("accountName", None)
if auth_account_id:
payment_account_id = (
db.session.query(PaymentAccount.id).filter(PaymentAccount.auth_account_id == auth_account_id).scalar()
)
query = query.filter(Invoice.payment_account_id == (payment_account_id or -1))
if account_name:
if include_joins:
query = query.join(PaymentAccount, PaymentAccount.id == Invoice.payment_account_id)
query = query.filter(PaymentAccount.name.ilike(f"%{account_name}%"))
return query

@classmethod
def filter_corp_type(cls, query, search_filter: dict):
"""Filter for corp type."""
Expand Down Expand Up @@ -470,9 +493,13 @@ def filter_date(cls, query, search_filter: dict):
return query

@classmethod
def filter_details(cls, query, search_filter: dict):
def filter_details(cls, query, search_filter: dict, include_joins=False):
"""Filter by details."""
if line_item := search_filter.get("lineItems", None):
line_item = search_filter.get("lineItems", None)
line_item_or_details = search_filter.get("lineItemsAndDetails", None)
if (line_item or line_item_or_details) and include_joins:
query = query.join(PaymentLineItem, PaymentLineItem.invoice_id == Invoice.id)
if line_item:
query = query.filter(PaymentLineItem.description.ilike(f"%{line_item}%"))
if details := search_filter.get("details", None):
query = query.filter(
Expand All @@ -485,7 +512,7 @@ def filter_details(cls, query, search_filter: dict):
),
)
)
if line_item_or_details := search_filter.get("lineItemsAndDetails", None):
if line_item_or_details:
query = query.filter(
or_(
PaymentLineItem.description.ilike(f"%{line_item_or_details}%"),
Expand All @@ -505,11 +532,10 @@ def filter_details(cls, query, search_filter: dict):
@classmethod
def generate_subquery(cls, auth_account_id, search_filter, limit, page):
"""Generate subquery for invoices, used for pagination."""
subquery = cls.generate_base_transaction_query()
subquery = db.session.query(Invoice.id)
subquery = (
cls.filter(subquery, auth_account_id, search_filter)
.with_entities(Invoice.id)
.group_by(Invoice.id)
cls.filter(subquery, auth_account_id, search_filter, include_joins=True)
.distinct()
.order_by(Invoice.id.desc())
)
if limit:
Expand Down
2 changes: 2 additions & 0 deletions pay-api/src/pay_api/utils/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
class JSONPath(UserDefinedType):
"""Used to define json path when casting."""

cache_ok = True

@property
def python_type(self):
"""Return the python type."""
Expand Down

0 comments on commit 8c1dd70

Please sign in to comment.