Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added server-side parameters for the thrift connection type #577

Merged
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20221229-200956.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Support server-side parameters in thrift connection
time: 2022-12-29T20:09:56.457776+02:00
custom:
Author: 'hanna-liashchuk'
Issue: "387"
PR: "577"
6 changes: 5 additions & 1 deletion dbt/adapters/spark/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,10 @@ def open(cls, connection: Connection) -> Connection:
kerberos_service_name=creds.kerberos_service_name,
password=creds.password,
)
conn = hive.connect(thrift_transport=transport)
conn = hive.connect(
thrift_transport=transport,
configuration=creds.server_side_parameters,
)
else:
conn = hive.connect(
host=creds.host,
Expand All @@ -448,6 +451,7 @@ def open(cls, connection: Connection) -> Connection:
auth=creds.auth,
kerberos_service_name=creds.kerberos_service_name,
password=creds.password,
configuration=creds.server_side_parameters,
) # noqa
handle = PyhiveConnectionWrapper(conn)
elif creds.method == SparkConnectionMethod.ODBC:
Expand Down
109 changes: 56 additions & 53 deletions tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,15 @@ def test_thrift_connection(self):
config = self._get_target_thrift(self.project_cfg)
adapter = SparkAdapter(config)

def hive_thrift_connect(host, port, username, auth, kerberos_service_name, password):
self.assertEqual(host, "myorg.sparkhost.com")

def hive_thrift_connect(host, port, username, auth, kerberos_service_name, password, configuration):
self.assertEqual(host, 'myorg.sparkhost.com')
self.assertEqual(port, 10001)
self.assertEqual(username, "dbt")
self.assertIsNone(auth)
self.assertIsNone(kerberos_service_name)
self.assertIsNone(password)
self.assertDictEqual(configuration, {})

with mock.patch.object(hive, "connect", new=hive_thrift_connect):
connection = adapter.acquire_connection("dummy")
Expand All @@ -194,11 +196,12 @@ def test_thrift_ssl_connection(self):
config = self._get_target_use_ssl_thrift(self.project_cfg)
adapter = SparkAdapter(config)

def hive_thrift_connect(thrift_transport):
def hive_thrift_connect(thrift_transport, configuration):
self.assertIsNotNone(thrift_transport)
transport = thrift_transport._trans
self.assertEqual(transport.host, "myorg.sparkhost.com")
self.assertEqual(transport.port, 10001)
self.assertDictEqual(configuration, {})

with mock.patch.object(hive, "connect", new=hive_thrift_connect):
connection = adapter.acquire_connection("dummy")
Expand All @@ -213,13 +216,14 @@ def test_thrift_connection_kerberos(self):
config = self._get_target_thrift_kerberos(self.project_cfg)
adapter = SparkAdapter(config)

def hive_thrift_connect(host, port, username, auth, kerberos_service_name, password):
self.assertEqual(host, "myorg.sparkhost.com")
def hive_thrift_connect(host, port, username, auth, kerberos_service_name, password, configuration):
self.assertEqual(host, 'myorg.sparkhost.com')
self.assertEqual(port, 10001)
self.assertEqual(username, "dbt")
self.assertEqual(auth, "KERBEROS")
self.assertEqual(kerberos_service_name, "hive")
self.assertIsNone(password)
self.assertDictEqual(configuration, {})

with mock.patch.object(hive, "connect", new=hive_thrift_connect):
connection = adapter.acquire_connection("dummy")
Expand Down Expand Up @@ -710,52 +714,51 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel
config = self._get_target_http(self.project_cfg)
columns = SparkAdapter(config).parse_columns_from_information(relation)
self.assertEqual(len(columns), 4)
self.assertEqual(
columns[2].to_column_dict(omit_none=False),
{
"table_database": None,
"table_schema": relation.schema,
"table_name": relation.name,
"table_type": rel_type,
"table_owner": "root",
"column": "dt",
"column_index": 2,
"dtype": "date",
"numeric_scale": None,
"numeric_precision": None,
"char_size": None,
"stats:bytes:description": "",
"stats:bytes:include": True,
"stats:bytes:label": "bytes",
"stats:bytes:value": 1234567890,
"stats:rows:description": "",
"stats:rows:include": True,
"stats:rows:label": "rows",
"stats:rows:value": 12345678,
},
)

self.assertEqual(
columns[3].to_column_dict(omit_none=False),
{
"table_database": None,
"table_schema": relation.schema,
"table_name": relation.name,
"table_type": rel_type,
"table_owner": "root",
"column": "struct_col",
"column_index": 3,
"dtype": "struct",
"numeric_scale": None,
"numeric_precision": None,
"char_size": None,
"stats:bytes:description": "",
"stats:bytes:include": True,
"stats:bytes:label": "bytes",
"stats:bytes:value": 1234567890,
"stats:rows:description": "",
"stats:rows:include": True,
"stats:rows:label": "rows",
"stats:rows:value": 12345678,
},
)
self.assertEqual(columns[2].to_column_dict(omit_none=False), {
'table_database': None,
'table_schema': relation.schema,
'table_name': relation.name,
'table_type': rel_type,
'table_owner': 'root',
'column': 'dt',
'column_index': 2,
'dtype': 'date',
'numeric_scale': None,
'numeric_precision': None,
'char_size': None,

'stats:bytes:description': '',
'stats:bytes:include': True,
'stats:bytes:label': 'bytes',
'stats:bytes:value': 1234567890,

'stats:rows:description': '',
'stats:rows:include': True,
'stats:rows:label': 'rows',
'stats:rows:value': 12345678
})

self.assertEqual(columns[3].to_column_dict(omit_none=False), {
'table_database': None,
'table_schema': relation.schema,
'table_name': relation.name,
'table_type': rel_type,
'table_owner': 'root',
'column': 'struct_col',
'column_index': 3,
'dtype': 'struct',
'numeric_scale': None,
'numeric_precision': None,
'char_size': None,

'stats:bytes:description': '',
'stats:bytes:include': True,
'stats:bytes:label': 'bytes',
'stats:bytes:value': 1234567890,

'stats:rows:description': '',
'stats:rows:include': True,
'stats:rows:label': 'rows',
'stats:rows:value': 12345678
})