Skip to content

Commit

Permalink
[BUG] Fix Control Message Utils & SQL Max Connections Exhaust (#1243)
Browse files Browse the repository at this point in the history
- Updated SQL loader to utilize connections from the pool.
- Fixed a "variable referenced before assignment" error in the control message utility.
closes #1237 #1235

Authors:
  - Bhargav Suryadevara (https://github.com/bsuryadevara)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #1243
  • Loading branch information
bsuryadevara authored Oct 5, 2023
1 parent 7aaec71 commit 1b6e9f2
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 24 deletions.
46 changes: 25 additions & 21 deletions morpheus/loaders/sql_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import pandas as pd
from sqlalchemy import create_engine
from sqlalchemy import engine

import cudf

Expand All @@ -40,30 +41,27 @@ def _parse_query_data(
Parameters
----------
query_data : Dict[str, Union[str, Optional[Dict[str, Any]]]]
The dictionary containing the connection string, query, and params (optional).
The dictionary containing the query, and params (optional).
Returns
-------
Dict[str, Union[str, Optional[Dict[str, Any]]]]
A dictionary containing the parsed query and params (if present).
"""

return {
"connection_string": query_data["connection_string"],
"query": query_data["query"],
"params": query_data.get("params", None)
}
return {"query": query_data["query"], "params": query_data.get("params", None)}


def _read_sql(connection_string: str, query: str, params: typing.Optional[typing.Dict[str, typing.Any]] = None) -> \
typing.Dict[str, pd.DataFrame]:
def _read_sql(engine_obj: engine.Engine,
query: str,
params: typing.Optional[typing.Dict[str, typing.Any]] = None) -> typing.Dict[str, pd.DataFrame]:
"""
Creates a DataFrame from a SQL query.
Parameters
----------
connection_string : str
Connection string to the database.
engine_obj : engine.Engine
SQL engine instance.
query : str
SQL query.
params : Optional[Dict[str, Any]], default=None
Expand All @@ -75,14 +73,10 @@ def _read_sql(connection_string: str, query: str, params: typing.Optional[typing
A dictionary containing a DataFrame of the SQL query result.
"""

# TODO(Devin): PERFORMANCE OPTIMIZATION
# TODO(Devin): Add connection pooling -- Probably needs to go on the actual loader
engine = create_engine(connection_string)

if (params is None):
df = pd.read_sql(query, engine)
df = pd.read_sql(query, engine_obj)
else:
df = pd.read_sql(query, engine, params=params)
df = pd.read_sql(query, engine_obj, params=params)

return {"df": df}

Expand Down Expand Up @@ -132,14 +126,24 @@ def sql_loader(control_message: ControlMessage, task: typing.Dict[str, typing.An

with CMDefaultFailureContextManager(control_message):
final_df = None
engine_registry = {}

sql_config = task["sql_config"]
queries = sql_config["queries"]

for query_data in queries:
aggregate_df = functools.partial(_aggregate_df, df_aggregate=final_df)
execution_chain = ExecutionChain(function_chain=[_parse_query_data, _read_sql, aggregate_df])
final_df = execution_chain(query_data=query_data)
try:
for query_data in queries:
conn_str = query_data.pop("connection_string")
if conn_str not in engine_registry:
engine_registry[conn_str] = create_engine(conn_str)

aggregate_df = functools.partial(_aggregate_df, df_aggregate=final_df)
read_sql = functools.partial(_read_sql, engine_obj=engine_registry[conn_str])
execution_chain = ExecutionChain(function_chain=[_parse_query_data, read_sql, aggregate_df])
final_df = execution_chain(query_data=query_data)
finally:
# Dispose all open connections.
for engine_obj in engine_registry.values():
engine_obj.dispose()

control_message.payload(MessageMeta(final_df))

Expand Down
7 changes: 4 additions & 3 deletions morpheus/utils/control_message_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,11 @@ def cm_default_failure_context_manager(raise_on_failure: bool = False) -> typing
def decorator(func):

@wraps(func)
def wrapper(control_messsage: ControlMessage, *args, **kwargs):
with CMDefaultFailureContextManager(control_message=control_messsage,
def wrapper(control_message: ControlMessage, *args, **kwargs):
ret_cm = control_message
with CMDefaultFailureContextManager(control_message=control_message,
raise_on_failure=raise_on_failure) as ctx_mgr:
cm_ensure_payload_not_null(control_message=control_messsage)
cm_ensure_payload_not_null(control_message=control_message)
ret_cm = func(ctx_mgr.control_message, *args, **kwargs)

return ret_cm
Expand Down

0 comments on commit 1b6e9f2

Please sign in to comment.