From 210c7ef7240f3c40af4fc92ee8d81bbf3cd9ce7d Mon Sep 17 00:00:00 2001 From: Dave Gittins Date: Wed, 1 Nov 2023 18:05:17 +0100 Subject: [PATCH] Add support for legacy auth --- sqlalchemy_dremio/db.py | 24 +++++++++++++++++++++--- sqlalchemy_dremio/flight.py | 4 ++-- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/sqlalchemy_dremio/db.py b/sqlalchemy_dremio/db.py index 7734d23..b350c30 100644 --- a/sqlalchemy_dremio/db.py +++ b/sqlalchemy_dremio/db.py @@ -43,6 +43,21 @@ def d(self, *args, **kwargs): return d +class LegacyAuthHandler(flight.ClientAuthHandler): + def __init__(self, username: str, password: str): + super().__init__() + self.basic_auth = flight.BasicAuth(username, password) + self.token = None + + def authenticate(self, outgoing, incoming): + auth = self.basic_auth.serialize() + outgoing.write(auth) + self.token = incoming.read() + + def get_token(self): + return self.token + + class Connection(object): def __init__(self, connection_string): @@ -79,12 +94,15 @@ def __init__(self, connection_string): client = flight.FlightClient('grpc+{0}://{1}:{2}'.format(protocol, properties['HOST'], properties['PORT']), middleware=[client_cookie_middleware], **connection_args) - + # Authenticate either using basic username/password or using the Token parameter. headers = [] if 'UID' in properties: - bearer_token = client.authenticate_basic_token(properties['UID'], properties['PWD']) - headers.append(bearer_token) + if 'UseLegacyAuth' in properties and properties['UseLegacyAuth'].lower() == 'true': + client.authenticate(LegacyAuthHandler(properties['UID'], properties['PWD'])) + else: + bearer_token = client.authenticate_basic_token(properties['UID'], properties['PWD']) + headers.append(bearer_token) else: headers.append((b'authorization', "Bearer {}".format(properties['Token']).encode('utf-8'))) diff --git a/sqlalchemy_dremio/flight.py b/sqlalchemy_dremio/flight.py index 1db1c91..0a401bf 100644 --- a/sqlalchemy_dremio/flight.py +++ b/sqlalchemy_dremio/flight.py @@ -178,7 +178,7 @@ def create_connect_args(self, url): def add_property(lc_query_dict, property_name, connectors): if property_name.lower() in lc_query_dict: connectors.append('{0}={1}'.format(property_name, lc_query_dict[property_name.lower()])) - + add_property(lc_query_dict, 'UseEncryption', connectors) add_property(lc_query_dict, 'DisableCertificateVerification', connectors) add_property(lc_query_dict, 'TrustedCerts', connectors) @@ -187,6 +187,7 @@ def add_property(lc_query_dict, property_name, connectors): add_property(lc_query_dict, 'quoting', connectors) add_property(lc_query_dict, 'routing_engine', connectors) add_property(lc_query_dict, 'Token', connectors) + add_property(lc_query_dict, 'UseLegacyAuth', connectors) return [[";".join(connectors)], connect_args] @@ -257,4 +258,3 @@ def has_table(self, connection, table_name, schema=None, **kw): def get_view_names(self, connection, schema=None, **kwargs): return [] - \ No newline at end of file