Skip to content

Commit

Permalink
Added server side parameters for thrift connection type
Browse files Browse the repository at this point in the history
  • Loading branch information
hanna-liashchuk committed Jan 19, 2023
1 parent 89f0cc2 commit e8bc6cd
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20221229-200956.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Support server side parameters in thrift connection
time: 2022-12-29T20:09:56.457776+02:00
custom:
Author: ' hanna-liashchuk'
Issue: "387"
PR: "577"
6 changes: 5 additions & 1 deletion dbt/adapters/spark/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,10 @@ def open(cls, connection):
kerberos_service_name=creds.kerberos_service_name,
password=creds.password,
)
conn = hive.connect(thrift_transport=transport)
conn = hive.connect(
thrift_transport=transport,
configuration=creds.server_side_parameters
)
else:
conn = hive.connect(
host=creds.host,
Expand All @@ -388,6 +391,7 @@ def open(cls, connection):
auth=creds.auth,
kerberos_service_name=creds.kerberos_service_name,
password=creds.password,
configuration=creds.server_side_parameters
) # noqa
handle = PyhiveConnectionWrapper(conn)
elif creds.method == SparkConnectionMethod.ODBC:
Expand Down
10 changes: 6 additions & 4 deletions tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,14 @@ def test_thrift_connection(self):
config = self._get_target_thrift(self.project_cfg)
adapter = SparkAdapter(config)

def hive_thrift_connect(host, port, username, auth, kerberos_service_name, password):
def hive_thrift_connect(host, port, username, auth, kerberos_service_name, password, configuration):
self.assertEqual(host, 'myorg.sparkhost.com')
self.assertEqual(port, 10001)
self.assertEqual(username, 'dbt')
self.assertIsNone(auth)
self.assertIsNone(kerberos_service_name)
self.assertIsNone(password)
self.assertDictEqual(configuration, {})

with mock.patch.object(hive, 'connect', new=hive_thrift_connect):
connection = adapter.acquire_connection('dummy')
Expand All @@ -175,11 +176,12 @@ def test_thrift_ssl_connection(self):
config = self._get_target_use_ssl_thrift(self.project_cfg)
adapter = SparkAdapter(config)

def hive_thrift_connect(thrift_transport):
def hive_thrift_connect(thrift_transport, configuration):
self.assertIsNotNone(thrift_transport)
transport = thrift_transport._trans
self.assertEqual(transport.host, 'myorg.sparkhost.com')
self.assertEqual(transport.port, 10001)
self.assertDictEqual(configuration, {})

with mock.patch.object(hive, 'connect', new=hive_thrift_connect):
connection = adapter.acquire_connection('dummy')
Expand All @@ -194,13 +196,14 @@ def test_thrift_connection_kerberos(self):
config = self._get_target_thrift_kerberos(self.project_cfg)
adapter = SparkAdapter(config)

def hive_thrift_connect(host, port, username, auth, kerberos_service_name, password):
def hive_thrift_connect(host, port, username, auth, kerberos_service_name, password, configuration):
self.assertEqual(host, 'myorg.sparkhost.com')
self.assertEqual(port, 10001)
self.assertEqual(username, 'dbt')
self.assertEqual(auth, 'KERBEROS')
self.assertEqual(kerberos_service_name, 'hive')
self.assertIsNone(password)
self.assertDictEqual(configuration, {})

with mock.patch.object(hive, 'connect', new=hive_thrift_connect):
connection = adapter.acquire_connection('dummy')
Expand Down Expand Up @@ -734,4 +737,3 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel
'stats:rows:label': 'rows',
'stats:rows:value': 12345678
})

0 comments on commit e8bc6cd

Please sign in to comment.