diff --git a/.changes/unreleased/Features-20221229-200956.yaml b/.changes/unreleased/Features-20221229-200956.yaml new file mode 100644 index 000000000..1add9bf72 --- /dev/null +++ b/.changes/unreleased/Features-20221229-200956.yaml @@ -0,0 +1,7 @@ +kind: Features +body: Support server side parameters in thrift connection +time: 2022-12-29T20:09:56.457776+02:00 +custom: + Author: ' hanna-liashchuk' + Issue: "387" + PR: "577" diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py index 5756aba3c..a939ae753 100644 --- a/dbt/adapters/spark/connections.py +++ b/dbt/adapters/spark/connections.py @@ -439,7 +439,10 @@ def open(cls, connection: Connection) -> Connection: kerberos_service_name=creds.kerberos_service_name, password=creds.password, ) - conn = hive.connect(thrift_transport=transport) + conn = hive.connect( + thrift_transport=transport, + configuration=creds.server_side_parameters, + ) else: conn = hive.connect( host=creds.host, @@ -448,6 +451,7 @@ def open(cls, connection: Connection) -> Connection: auth=creds.auth, kerberos_service_name=creds.kerberos_service_name, password=creds.password, + configuration=creds.server_side_parameters, ) # noqa handle = PyhiveConnectionWrapper(conn) elif creds.method == SparkConnectionMethod.ODBC: diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 1eb818241..a7da63301 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -173,13 +173,16 @@ def test_thrift_connection(self): config = self._get_target_thrift(self.project_cfg) adapter = SparkAdapter(config) - def hive_thrift_connect(host, port, username, auth, kerberos_service_name, password): + def hive_thrift_connect( + host, port, username, auth, kerberos_service_name, password, configuration + ): self.assertEqual(host, "myorg.sparkhost.com") self.assertEqual(port, 10001) self.assertEqual(username, "dbt") self.assertIsNone(auth) self.assertIsNone(kerberos_service_name) self.assertIsNone(password) + self.assertDictEqual(configuration, {}) with mock.patch.object(hive, "connect", new=hive_thrift_connect): connection = adapter.acquire_connection("dummy") @@ -194,11 +197,12 @@ def test_thrift_ssl_connection(self): config = self._get_target_use_ssl_thrift(self.project_cfg) adapter = SparkAdapter(config) - def hive_thrift_connect(thrift_transport): + def hive_thrift_connect(thrift_transport, configuration): self.assertIsNotNone(thrift_transport) transport = thrift_transport._trans self.assertEqual(transport.host, "myorg.sparkhost.com") self.assertEqual(transport.port, 10001) + self.assertDictEqual(configuration, {}) with mock.patch.object(hive, "connect", new=hive_thrift_connect): connection = adapter.acquire_connection("dummy") @@ -213,13 +217,16 @@ def test_thrift_connection_kerberos(self): config = self._get_target_thrift_kerberos(self.project_cfg) adapter = SparkAdapter(config) - def hive_thrift_connect(host, port, username, auth, kerberos_service_name, password): + def hive_thrift_connect( + host, port, username, auth, kerberos_service_name, password, configuration + ): self.assertEqual(host, "myorg.sparkhost.com") self.assertEqual(port, 10001) self.assertEqual(username, "dbt") self.assertEqual(auth, "KERBEROS") self.assertEqual(kerberos_service_name, "hive") self.assertIsNone(password) + self.assertDictEqual(configuration, {}) with mock.patch.object(hive, "connect", new=hive_thrift_connect): connection = adapter.acquire_connection("dummy") @@ -710,6 +717,7 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel config = self._get_target_http(self.project_cfg) columns = SparkAdapter(config).parse_columns_from_information(relation) self.assertEqual(len(columns), 4) + self.assertEqual( columns[2].to_column_dict(omit_none=False), {