Skip to content

Commit

Permalink
Incorporating feedback. Moved server side parameters to Connection an…
Browse files Browse the repository at this point in the history
…d pass to cursor from there.
  • Loading branch information
alarocca-apixio committed May 16, 2023
1 parent 69c299e commit 1bcd4f9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
2 changes: 1 addition & 1 deletion dbt/adapters/spark/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def open(cls, connection):
SessionConnectionWrapper,
)

handle = SessionConnectionWrapper(Connection(), creds.server_side_parameters)
handle = SessionConnectionWrapper(Connection(creds.server_side_parameters))
else:
raise dbt.exceptions.DbtProfileError(
f"invalid credential method: {creds.method}"
Expand Down
19 changes: 11 additions & 8 deletions dbt/adapters/spark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ class Cursor:
https://github.com/mkleehammer/pyodbc/wiki/Cursor
"""

def __init__(self) -> None:
def __init__(self, server_side_parameters) -> None:
self._df: Optional[DataFrame] = None
self._rows: Optional[List[Row]] = None
self.server_side_parameters = server_side_parameters

def __enter__(self) -> Cursor:
return self
Expand Down Expand Up @@ -84,7 +85,7 @@ def close(self) -> None:
self._df = None
self._rows = None

def execute(self, sql: str, server_side_parameters, *parameters: Any) -> None:
def execute(self, sql: str, *parameters: Any) -> None:
"""
Execute a sql statement.
Expand All @@ -108,7 +109,7 @@ def execute(self, sql: str, server_side_parameters, *parameters: Any) -> None:
sql = sql % parameters
builder = SparkSession.builder.enableHiveSupport()

for k, v in server_side_parameters.items():
for k, v in self.server_side_parameters.items():
builder = builder.config(k, v)

spark_session = builder.getOrCreate()
Expand Down Expand Up @@ -164,6 +165,9 @@ class Connection:
https://github.com/mkleehammer/pyodbc/wiki/Connection
"""

def __init__(self, server_side_parameters) -> None:
self.server_side_parameters = server_side_parameters

def cursor(self) -> Cursor:
"""
Get a cursor.
Expand All @@ -173,15 +177,14 @@ def cursor(self) -> Cursor:
out : Cursor
The cursor.
"""
return Cursor()
return Cursor(self.server_side_parameters)


class SessionConnectionWrapper(object):
"""Connection wrapper for the session connection method."""

def __init__(self, handle, server_side_parameters):
def __init__(self, handle):
self.handle = handle
self.server_side_parameters = server_side_parameters
self._cursor = None

def cursor(self):
Expand All @@ -206,10 +209,10 @@ def execute(self, sql, bindings=None):
sql = sql.strip()[:-1]

if bindings is None:
self._cursor.execute(sql, self.server_side_parameters)
self._cursor.execute(sql)
else:
bindings = [self._fix_binding(binding) for binding in bindings]
self._cursor.execute(sql, self.server_side_parameters, *bindings)
self._cursor.execute(sql, *bindings)

@property
def description(self):
Expand Down

0 comments on commit 1bcd4f9

Please sign in to comment.