diff --git a/setup.cfg b/setup.cfg index cfb9c9a..5536225 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,3 +21,8 @@ default = dremio://dremio:dremio123@localhost:31010/dremio tag_build = tag_date = 0 +[tool:pytest] +log_cli = 1 +log_cli_level = INFO +log_cli_format = %(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s) +log_cli_date_format=%Y-%m-%d %H:%M:%S diff --git a/sqlalchemy_dremio/base.py b/sqlalchemy_dremio/base.py index 521dbc3..0e42439 100644 --- a/sqlalchemy_dremio/base.py +++ b/sqlalchemy_dremio/base.py @@ -1,5 +1,8 @@ -from sqlalchemy import schema, types, pool +from typing import Optional + +from sqlalchemy import schema, types, pool, Table, Connection, Column from sqlalchemy.engine import default, reflection +from sqlalchemy.engine.interfaces import DBAPICursor, _DBAPISingleExecuteParams, ExecutionContext from sqlalchemy.sql import compiler _dialect_name = "dremio" @@ -34,7 +37,6 @@ 'smallint': types.SMALLINT, 'CHARACTER VARYING': types.VARCHAR, 'ANY': types.VARCHAR, - 'ARRAY': types.ARRAY, 'ROW': types.JSON, 'BINARY VARYING': types.LargeBinary, @@ -46,45 +48,40 @@ class DremioExecutionContext(default.DefaultExecutionContext): class DremioCompiler(compiler.SQLCompiler): - def visit_char_length_func(self, fn, **kw): - return 'length{}'.format(self.function_argspec(fn, **kw)) - - def visit_table(self, table, asfrom=False, **kwargs): + def visit_char_length_func(self, fn, **kw) -> str: + return f'length{self.function_argspec(fn, **kw)}' + def visit_table(self, table: Table, asfrom: bool = False, **kwargs) -> str: if asfrom: - if table.schema != None and table.schema != "": - fixed_schema = ".".join(["\"" + i.replace('"', '') + "\"" for i in table.schema.split(".")]) - fixed_table = fixed_schema + ".\"" + table.name.replace("\"", "") + "\"" + if table.schema is not None and table.schema != "": + fixed_schema = ".".join(['"' + i.replace('"', '') + '"' for i in table.schema.split(".")]) + fixed_table = fixed_schema + '"."' + 
table.name.replace('"', "") + '"' else: # don't change anything. expect a fully and properly qualified path if no schema is passed. fixed_table = table.name - # fixed_table = "\"" + table.name.replace("\"", "") + "\"" return fixed_table else: return "" - def visit_tablesample(self, tablesample, asfrom=False, **kw): + def visit_tablesample(self, tablesample, asfrom: bool = False, **kw) -> None: + # TODO: This is currently a noop print(tablesample) class DremioDDLCompiler(compiler.DDLCompiler): - def get_column_specification(self, column, **kwargs): + def get_column_specification(self, column: Column, **kwargs) -> str: colspec = self.preparer.format_column(column) colspec += " " + self.dialect.type_compiler.process(column.type) - if column is column.table._autoincrement_column and \ - True and \ - ( - column.default is None or \ - isinstance(column.default, schema.Sequence) - ): + if column is column.table._autoincrement_column and ( + column.default is None or isinstance(column.default, schema.Sequence) + ): colspec += " IDENTITY" - if isinstance(column.default, schema.Sequence) and \ - column.default.start > 0: + if isinstance(column.default, schema.Sequence) and column.default.start > 0: colspec += " " + str(column.default.start) else: default = self.get_column_default_string(column) if default is not None: - colspec += " DEFAULT " + default + colspec += f" DEFAULT {default}" if not column.nullable: colspec += " NOT NULL" @@ -141,9 +138,8 @@ class DremioIdentifierPreparer(compiler.IdentifierPreparer): dremio_unique = dremio_reserved - reserved_words reserved_words.update(list(dremio_unique)) - def __init__(self, dialect): - super(DremioIdentifierPreparer, self). 
\ - __init__(dialect, initial_quote='"', final_quote='"') + def __init__(self, dialect: default.Dialect): + super().__init__(dialect, initial_quote='"', final_quote='"') class DremioDialect(default.DefaultDialect): @@ -169,76 +165,77 @@ def connect(self, *cargs, **cparams): if 'autocommit' not in engine_params: cparams['autocommit'] = 1 - return self.dbapi.connect(*cargs, **cparams) + return self.loaded_dbapi.connect(*cargs, **cparams) def last_inserted_ids(self): return self.context.last_inserted_ids - def get_indexes(self, connection, table_name, schema, **kw): + def get_indexes(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kw): return [] - def get_pk_constraint(self, connection, table_name, schema=None, **kw): + def get_pk_constraint(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kw): return [] - def get_foreign_keys(self, connection, table_name, schema=None, **kw): + def get_foreign_keys(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kw): return [] - def get_columns(self, connection, table_name, schema, **kw): - sql = "DESCRIBE \"{0}\"".format(table_name) - if schema != None and schema != "": - sql = "DESCRIBE \"{0}\".\"{1}\"".format(schema, table_name) - cursor = connection.execute(sql) - result = [] - for col in cursor: - cname = col[0] - ctype = _type_map[col[1]] - column = { - "name": cname, - "type": ctype, - "default": None, - "comment": None, - "nullable": True - } - result.append(column) - return (result) + def get_columns(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kw) -> list[dict]: + sql = f'DESCRIBE "{table_name}"' + if schema is not None and schema != "": + sql = f'DESCRIBE "{schema}"."{table_name}"' + cursor = connection.exec_driver_sql(sql) + + result = [{ + "name": col[0], + "type": _type_map[col[1]], + "default": None, + "comment": None, + "nullable": True + } for col in cursor] + + return result @reflection.cache 
- def get_table_names(self, connection, schema, **kw): + def get_table_names(self, connection: Connection, schema: str, **kw) -> list[str]: sql = 'SELECT TABLE_NAME FROM INFORMATION_SCHEMA."TABLES"' # Reverting #5 as Dremio does not support parameterized queries. if schema is not None: - sql += " WHERE TABLE_SCHEMA = '" + schema + "'" + sql += f" WHERE TABLE_SCHEMA = {schema}" - result = connection.execute(sql) + result = connection.exec_driver_sql(sql) table_names = [r[0] for r in result] return table_names - def get_schema_names(self, connection, schema=None, **kw): + def get_schema_names(self, connection: Connection, schema: Optional[str] = None, **kw) -> list[str]: if len(self.filter_schema_names) > 0: return self.filter_schema_names - result = connection.execute("SHOW SCHEMAS") + result = connection.exec_driver_sql("SHOW SCHEMAS") schema_names = [r[0] for r in result] return schema_names @reflection.cache - def has_table(self, connection, table_name, schema=None, **kw): + def has_table(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kw) -> bool: sql = 'SELECT COUNT(*) FROM INFORMATION_SCHEMA."TABLES"' - sql += " WHERE TABLE_NAME = '" + str(table_name) + "'" + sql += f" WHERE TABLE_NAME = {table_name}" if schema is not None and schema != "": - sql += " AND TABLE_SCHEMA = '" + str(schema) + "'" - result = connection.execute(sql) - countRows = [r[0] for r in result] - return countRows[0] > 0 + sql += f" AND TABLE_SCHEMA = {schema}" + result = connection.exec_driver_sql(sql) + count_rows = [r[0] for r in result] + return count_rows[0] > 0 - def get_view_names(self, connection, schema=None, **kwargs): + def get_view_names(self, connection: Connection, schema: Optional[str] = None, **kwargs) -> list[str]: return [] # Workaround since Dremio does not support parameterized stmts # Old queries should not have used queries with parameters, since Dremio does not support it # and these queries failed. 
If there is no parameter, everything should work as before. - def do_execute(self, cursor, statement, parameters, context): + def do_execute(self, + cursor: DBAPICursor, + statement: str, + parameters: Optional[_DBAPISingleExecuteParams], + context: Optional[ExecutionContext] = None) -> None: replaced_stmt = statement for v in parameters: escaped_str = str(v).replace("'", "''") @@ -247,6 +244,4 @@ def do_execute(self, cursor, statement, parameters, context): else: replaced_stmt = replaced_stmt.replace('?', "'" + escaped_str + "'", 1) - super(DremioDialect, self).do_execute_no_params( - cursor, replaced_stmt, context - ) + super().do_execute_no_params(cursor, replaced_stmt, context) diff --git a/sqlalchemy_dremio/db.py b/sqlalchemy_dremio/db.py index 7734d23..e4736c9 100644 --- a/sqlalchemy_dremio/db.py +++ b/sqlalchemy_dremio/db.py @@ -1,11 +1,9 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - import logging +from typing import Optional, Any, Never, Generator, Union from pyarrow import flight +from pyarrow._flight import FlightClient +from sqlalchemy import Executable, types from sqlalchemy_dremio.exceptions import Error, NotSupportedError from sqlalchemy_dremio.flight_middleware import CookieMiddlewareFactory @@ -16,6 +14,10 @@ paramstyle = 'qmark' +class Binary(types.LargeBinary): + __visit_name__ = "VARBINARY" + + def connect(c): return Connection(c) @@ -26,7 +28,7 @@ def check_closed(f): def g(self, *args, **kwargs): if self.closed: raise Error( - '{klass} already closed'.format(klass=self.__class__.__name__)) + f'{self.__class__.__name__} already closed') return f(self, *args, **kwargs) return g @@ -43,9 +45,9 @@ def d(self, *args, **kwargs): return d -class Connection(object): +class Connection: - def __init__(self, connection_string): + def __init__(self, connection_string: str): # Build a map from the connection string supplied using the SQLAlchemy 
URI # and supplied properties. The format is generated from DremioDialect_flight.create_connect_args() @@ -55,7 +57,7 @@ def __init__(self, connection_string): splits = connection_string.split(";") for kvpair in splits: - kv = kvpair.split("=",1) + kv = kvpair.split("=", 1) properties[kv[0]] = kv[1] connection_args = {} @@ -68,32 +70,33 @@ def __init__(self, connection_string): # Specify the trusted certificates connection_args['disable_server_verification'] = False if 'TrustedCerts' in properties: - with open(properties['TrustedCerts'] , "rb") as root_certs: + with open(properties['TrustedCerts'], "rb") as root_certs: connection_args["tls_root_certs"] = root_certs.read() # Or disable server verification entirely - elif 'DisableCertificateVerification' in properties and properties['DisableCertificateVerification'].lower() == 'true': + elif 'DisableCertificateVerification' in properties and properties[ + 'DisableCertificateVerification'].lower() == 'true': connection_args['disable_server_verification'] = True # Enabling cookie middleware for stateful connectivity. client_cookie_middleware = CookieMiddlewareFactory() - client = flight.FlightClient('grpc+{0}://{1}:{2}'.format(protocol, properties['HOST'], properties['PORT']), - middleware=[client_cookie_middleware], **connection_args) - + client = flight.FlightClient(f'grpc+{protocol}://{properties["HOST"]}:{properties["PORT"]}', + middleware=[client_cookie_middleware], **connection_args) + # Authenticate either using basic username/password or using the Token parameter. headers = [] if 'UID' in properties: bearer_token = client.authenticate_basic_token(properties['UID'], properties['PWD']) headers.append(bearer_token) else: - headers.append((b'authorization', "Bearer {}".format(properties['Token']).encode('utf-8'))) + headers.append((b'authorization', f"Bearer {properties["Token"]}".encode('utf-8'))) # Propagate Dremio-specific headers. 
- def add_header(properties, headers, header_name): + def add_header(properties: dict, headers: list[tuple[bytes, bytes]], header_name: str) -> None: if header_name in properties: headers.append((header_name.lower().encode('utf-8'), properties[header_name].encode('utf-8'))) - add_header(properties, headers, 'Schema') + add_header(properties, headers, 'schema') add_header(properties, headers, 'routing_queue') add_header(properties, headers, 'routing_tag') add_header(properties, headers, 'quoting') @@ -106,11 +109,11 @@ def add_header(properties, headers, header_name): self.cursors = [] @check_closed - def rollback(self): + def rollback(self) -> None: pass @check_closed - def close(self): + def close(self) -> None: """Close the connection now.""" self.closed = True for cursor in self.cursors: @@ -120,11 +123,11 @@ def close(self): pass # already closed @check_closed - def commit(self): + def commit(self) -> None: pass @check_closed - def cursor(self): + def cursor(self) -> "Cursor": """Return a new Cursor Object using the connection.""" cursor = Cursor(self.flightclient, self.options) self.cursors.append(cursor) @@ -132,22 +135,22 @@ def cursor(self): return cursor @check_closed - def execute(self, query): + def execute(self, query: Executable) -> "Cursor": cursor = self.cursor() return cursor.execute(query) - def __enter__(self): + def __enter__(self) -> "Connection": return self - def __exit__(self, *exc): + def __exit__(self, *exc) -> None: self.commit() # no-op self.close() -class Cursor(object): +class Cursor: """Connection cursor.""" - def __init__(self, flightclient=None, options=None): + def __init__(self, flightclient: Optional[FlightClient] = None, options: Optional[Any] = None): self.flightclient = flightclient self.options = options @@ -167,29 +170,33 @@ def __init__(self, flightclient=None, options=None): @property @check_result @check_closed - def rowcount(self): + def rowcount(self) -> int: return len(self._results) @check_closed - def close(self): + 
def close(self) -> None: """Close the cursor.""" self.closed = True @check_closed - def execute(self, query, params=None): + def execute(self, query: Union[Executable, str], params: Optional[tuple[Any, ...]] = None) -> "Cursor": self.description = None - self._results, self.description = execute( - query, self.flightclient, self.options) + if params is not None: + for param in params: + if isinstance(param, str): + param = f"'{param}'" + query = query.replace('?', str(param), 1) + self._results, self.description = execute(query, self.flightclient, self.options) return self @check_closed - def executemany(self, query): + def executemany(self, query: str) -> Never: raise NotSupportedError( '`executemany` is not supported, use `execute` instead') @check_result @check_closed - def fetchone(self): + def fetchone(self) -> Optional[Any]: """ Fetch the next row of a query result set, returning a single sequence, or `None` when no more data is available. @@ -201,7 +208,7 @@ def fetchone(self): @check_result @check_closed - def fetchmany(self, size=None): + def fetchmany(self, size: Optional[int] = None) -> list[tuple[Any, ...]]: """ Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a list of tuples). An empty sequence is returned when @@ -214,7 +221,7 @@ def fetchmany(self, size=None): @check_result @check_closed - def fetchall(self): + def fetchall(self) -> list[tuple[Any, ...]]: """ Fetch all (remaining) rows of a query result, returning them as a sequence of sequences (e.g. a list of tuples). 
Note that the cursor's @@ -225,15 +232,15 @@ def fetchall(self): return out @check_closed - def setinputsizes(self, sizes): + def setinputsizes(self, sizes) -> None: # not supported - pass + raise NotSupportedError() @check_closed - def setoutputsizes(self, sizes): + def setoutputsizes(self, sizes) -> None: # not supported - pass + raise NotSupportedError() @check_closed - def __iter__(self): + def __iter__(self) -> Generator[tuple[Any, ...], None, None]: return iter(self._results) diff --git a/sqlalchemy_dremio/exceptions.py b/sqlalchemy_dremio/exceptions.py index bfb8cb3..70116d1 100644 --- a/sqlalchemy_dremio/exceptions.py +++ b/sqlalchemy_dremio/exceptions.py @@ -1,9 +1,3 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - - class Error(Exception): pass diff --git a/sqlalchemy_dremio/flight.py b/sqlalchemy_dremio/flight.py index 1db1c91..f76a1b9 100644 --- a/sqlalchemy_dremio/flight.py +++ b/sqlalchemy_dremio/flight.py @@ -1,9 +1,14 @@ -import re +from types import ModuleType +from typing import Optional -from sqlalchemy import schema, types, pool +from sqlalchemy import schema, types, pool, Table, URL, Connection from sqlalchemy.engine import default, reflection +from sqlalchemy.engine.interfaces import ReflectedIndex, ReflectedPrimaryKeyConstraint, ReflectedForeignKeyConstraint, \ + ReflectedColumn from sqlalchemy.sql import compiler +from sqlalchemy_dremio import exceptions + _dialect_name = "dremio+flight" _type_map = { @@ -43,23 +48,25 @@ class DremioExecutionContext(default.DefaultExecutionContext): pass + + class DremioCompiler(compiler.SQLCompiler): def visit_char_length_func(self, fn, **kw): - return 'length{}'.format(self.function_argspec(fn, **kw)) - - def visit_table(self, table, asfrom=False, **kwargs): + return f'length{self.function_argspec(fn, **kw)}' + def visit_table(self, table: Table, asfrom: bool = False, **kwargs) -> str: if asfrom: - 
if table.schema != None and table.schema != "": - fixed_schema = ".".join(["\"" + i.replace('"', '') + "\"" for i in table.schema.split(".")]) - fixed_table = fixed_schema + ".\"" + table.name.replace("\"", "") + "\"" + if table.schema is not None and table.schema != "": + fixed_schema = ".".join([f'"{i.replace('"', '')}"' for i in table.schema.split(".")]) + fixed_table = f'{fixed_schema}."{table.name.replace('"', "")}"' else: - fixed_table = "\"" + table.name.replace("\"", "") + "\"" + fixed_table = f'"{table.name.replace('"', "")}"' return fixed_table else: return "" def visit_tablesample(self, tablesample, asfrom=False, **kw): + # TODO: This is currently a noop print(tablesample) @@ -67,15 +74,11 @@ class DremioDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) colspec += " " + self.dialect.type_compiler.process(column.type) - if column is column.table._autoincrement_column and \ - True and \ - ( - column.default is None or \ - isinstance(column.default, schema.Sequence) - ): + if column is column.table._autoincrement_column and ( + column.default is None or isinstance(column.default, schema.Sequence) + ): colspec += " IDENTITY" - if isinstance(column.default, schema.Sequence) and \ - column.default.start > 0: + if isinstance(column.default, schema.Sequence) and column.default.start > 0: colspec += " " + str(column.default.start) else: default = self.get_column_default_string(column) @@ -138,8 +141,7 @@ class DremioIdentifierPreparer(compiler.IdentifierPreparer): reserved_words.update(list(dremio_unique)) def __init__(self, dialect): - super(DremioIdentifierPreparer, self). 
\ - __init__(dialect, initial_quote='"', final_quote='"') + super().__init__(dialect, initial_quote='"', final_quote='"') class DremioExecutionContext_flight(DremioExecutionContext): @@ -147,7 +149,6 @@ class DremioExecutionContext_flight(DremioExecutionContext): class DremioDialect_flight(default.DefaultDialect): - name = _dialect_name driver = _dialect_name supports_sane_rowcount = False @@ -159,26 +160,26 @@ class DremioDialect_flight(default.DefaultDialect): preparer = DremioIdentifierPreparer execution_ctx_cls = DremioExecutionContext - def create_connect_args(self, url): + def create_connect_args(self, url: URL): opts = url.translate_connect_args(username='user') connect_args = {} - connectors = ['HOST=%s' % opts['host'], - 'PORT=%s' % opts['port']] + connectors = [f'HOST={opts["host"]}', + f'PORT={opts["port"]}'] if 'user' in opts: - connectors.append('{0}={1}'.format('UID', opts['user'])) - connectors.append('{0}={1}'.format('PWD', opts['password'])) + connectors.append(f'UID={opts["user"]}') + connectors.append(f'PWD={opts['password']}') if 'database' in opts: - connectors.append('{0}={1}'.format('Schema', opts['database'])) + connectors.append(f'schema={opts["database"]}') # Clone the query dictionary with lower-case keys. 
lc_query_dict = {k.lower(): v for k, v in url.query.items()} - def add_property(lc_query_dict, property_name, connectors): + def add_property(lc_query_dict: dict, property_name: str, connectors: list[str]): if property_name.lower() in lc_query_dict: - connectors.append('{0}={1}'.format(property_name, lc_query_dict[property_name.lower()])) - + connectors.append(f'{property_name}={lc_query_dict[property_name.lower()]}') + add_property(lc_query_dict, 'UseEncryption', connectors) add_property(lc_query_dict, 'DisableCertificateVerification', connectors) add_property(lc_query_dict, 'TrustedCerts', connectors) @@ -191,70 +192,69 @@ def add_property(lc_query_dict, property_name, connectors): return [[";".join(connectors)], connect_args] @classmethod - def dbapi(cls): + def dbapi(cls) -> ModuleType: import sqlalchemy_dremio.db as module return module - def connect(self, *cargs, **cparams): - return self.dbapi.connect(*cargs, **cparams) + def connect(self, *cargs, **cparams) -> Connection: + return self.loaded_dbapi.connect(*cargs, **cparams) def last_inserted_ids(self): return self.context.last_inserted_ids - def get_indexes(self, connection, table_name, schema, **kw): - return [] - - def get_pk_constraint(self, connection, table_name, schema=None, **kw): - return [] - - def get_foreign_keys(self, connection, table_name, schema=None, **kw): - return [] - - def get_columns(self, connection, table_name, schema, **kw): - sql = "DESCRIBE \"{0}\"".format(table_name) - if schema != None and schema != "": - sql = "DESCRIBE \"{0}\".\"{1}\"".format(schema, table_name) - cursor = connection.execute(sql) - result = [] - for col in cursor: - cname = col[0] - ctype = _type_map[col[1]] - column = { - "name": cname, - "type": ctype, - "default": None, - "comment": None, - "nullable": True - } - result.append(column) - return (result) + def get_indexes(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kw) -> list[ + ReflectedIndex]: + raise 
exceptions.NotSupportedError() + + def get_pk_constraint(self, connection: Connection, table_name: str, schem: Optional[str] = None, + **kw) -> ReflectedPrimaryKeyConstraint: + raise exceptions.NotSupportedError() + + def get_foreign_keys(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kw) -> list[ + ReflectedForeignKeyConstraint]: + raise exceptions.NotSupportedError() + + def get_columns(self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw) -> list[ReflectedColumn]: + sql = f'DESCRIBE "{table_name}"' + if schema is not None and schema != "": + sql = f'DESCRIBE "{schema}"."{table_name}"' + cursor = connection.exec_driver_sql(sql) + result = [ReflectedColumn(name=col[0], + type=_type_map[col[1]], + default=None, + comment=None, + nullable=True) + for col in cursor] + return result @reflection.cache - def get_table_names(self, connection, schema, **kw): + def get_table_names(self, connection: Connection, schema: Optional[str] = None, **kw) -> list[str]: sql = 'SELECT TABLE_NAME FROM INFORMATION_SCHEMA."TABLES"' if schema is not None: - sql = "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.\"TABLES\" WHERE TABLE_SCHEMA = '" + schema + "'" + sql = f'SELECT TABLE_NAME FROM INFORMATION_SCHEMA."TABLES" WHERE TABLE_SCHEMA = "{schema}"' - result = connection.execute(sql) + result = connection.exec_driver_sql(sql) table_names = [r[0] for r in result] return table_names - def get_schema_names(self, connection, schema=None, **kw): - result = connection.execute("SHOW SCHEMAS") + def get_schema_names(self, connection: Connection, schema: Optional[str] = None, **kw) -> list[str]: + result = connection.exec_driver_sql("SHOW SCHEMAS") schema_names = [r[0] for r in result] return schema_names - @reflection.cache - def has_table(self, connection, table_name, schema=None, **kw): + def has_table(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kw) -> bool: sql = 'SELECT COUNT(*) FROM 
INFORMATION_SCHEMA."TABLES"' - sql += " WHERE TABLE_NAME = '" + str(table_name) + "'" + sql += f" WHERE TABLE_NAME = '{table_name}'" if schema is not None and schema != "": - sql += " AND TABLE_SCHEMA = '" + str(schema) + "'" - result = connection.execute(sql) - countRows = [r[0] for r in result] - return countRows[0] > 0 - - def get_view_names(self, connection, schema=None, **kwargs): - return [] - \ No newline at end of file + sql += f" AND TABLE_SCHEMA = '{schema}'" + result = connection.exec_driver_sql(sql) + count_rows = [r[0] for r in result] + return count_rows[0] > 0 + + def get_view_names(self, connection: Connection, schema: Optional[str] = None, **kwargs) -> list[str]: + raise exceptions.NotSupportedError() diff --git a/sqlalchemy_dremio/flight_middleware.py b/sqlalchemy_dremio/flight_middleware.py index 786b647..2b3b470 100644 --- a/sqlalchemy_dremio/flight_middleware.py +++ b/sqlalchemy_dremio/flight_middleware.py @@ -1,17 +1,21 @@ +from http.cookies import SimpleCookie + +from pyarrow._flight import CallInfo from pyarrow.flight import ClientMiddleware from pyarrow.flight import ClientMiddlewareFactory -from http.cookies import SimpleCookie class CookieMiddlewareFactory(ClientMiddlewareFactory): """A factory that creates CookieMiddleware(s).""" - def __init__(self): + def __init__(self, *args, **kwargs) -> None: self.cookies = {} + super().__init__(*args, **kwargs) - def start_call(self, info): + def start_call(self, info: CallInfo) -> "CookieMiddleware": return CookieMiddleware(self) - + + class CookieMiddleware(ClientMiddleware): """ A ClientMiddleware that receives and retransmits cookies. @@ -22,10 +26,11 @@ class CookieMiddleware(ClientMiddleware): The factory containing the currently cached cookies. 
""" - def __init__(self, factory): + def __init__(self, factory: CookieMiddlewareFactory, *args, **kwargs) -> None: self.factory = factory + super().__init__(*args, **kwargs) - def received_headers(self, headers): + def received_headers(self, headers: dict[str, str]) -> None: for key in headers: if key.lower() == 'set-cookie': cookie = SimpleCookie() @@ -34,8 +39,8 @@ def received_headers(self, headers): self.factory.cookies.update(cookie.items()) - def sending_headers(self): + def sending_headers(self) -> dict[bytes, bytes]: if self.factory.cookies: - cookie_string = '; '.join("{!s}={!s}".format(key, val.value) for (key, val) in self.factory.cookies.items()) + cookie_string = '; '.join(f"{key}={val.value}" for key, val in self.factory.cookies.items()) return {b'cookie': cookie_string.encode('utf-8')} return {} diff --git a/sqlalchemy_dremio/query.py b/sqlalchemy_dremio/query.py index b58590d..2807659 100644 --- a/sqlalchemy_dremio/query.py +++ b/sqlalchemy_dremio/query.py @@ -1,12 +1,10 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - -from sqlalchemy import types +from typing import Optional, Any +import pandas as pd import pyarrow as pa from pyarrow import flight +from pyarrow._flight import FlightClient, FlightCallOptions +from sqlalchemy import types _type_map = { 'boolean': types.BOOLEAN, @@ -29,6 +27,7 @@ 'time': types.TIME, 'TIME': types.TIME, 'datetime64[ns]': types.DATETIME, + 'datetime64[ms]': types.DATETIME, 'timestamp': types.TIMESTAMP, 'TIMESTAMP': types.TIMESTAMP, 'varchar': types.VARCHAR, @@ -39,10 +38,10 @@ } -def run_query(query, flightclient=None, options=None): +def run_query(query: str, flightclient: Optional[FlightClient] = None, + options: Optional[FlightCallOptions] = None) -> pd.DataFrame: info = flightclient.get_flight_info(flight.FlightDescriptor.for_command(query), options) reader = flightclient.do_get(info.endpoints[0].ticket, 
options) - batches = [] while True: try: @@ -52,12 +51,15 @@ def run_query(query, flightclient=None, options=None): break data = pa.Table.from_batches(batches) + # TODO: Make pandas an optional dependency df = data.to_pandas() return df -def execute(query, flightclient=None, options=None): +def execute(query: str, + flightclient: Optional[FlightClient] = None, + options: Optional[FlightCallOptions] = None) -> tuple[list[Any], list[tuple]]: df = run_query(query, flightclient, options) result = [] diff --git a/test/__init__.py b/test/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/test/conftest.py b/test/conftest.py deleted file mode 100644 index e4628e9..0000000 --- a/test/conftest.py +++ /dev/null @@ -1,37 +0,0 @@ -import os -from pathlib import Path - -import pytest -from sqlalchemy import create_engine -import sqlalchemy.dialects - -sqlalchemy.dialects.registry.register("dremio", "sqlalchemy_dremio.flight", "DremioDialect_flight") - - -def help(): - print("""Connection string must be set as env variable, - for example: - Windows: setx DREMIO_CONNECTION_URL "dremio+flight://dremio:dremio123@localhost:32010/dremio" - Linux: export DREMIO_CONNECTION_URL="dremio+flight://dremio:dremio123@localhost:32010/dremio" - """) - - -def get_engine(): - """ - Creates a connection using the parameters defined in ODBC connect string - """ - if not os.environ['DREMIO_CONNECTION_URL']: - help() - return - return create_engine(os.environ['DREMIO_CONNECTION_URL']) - - -@pytest.fixture(scope='session', autouse=True) -def init_test_schema(request): - test_sql = Path("scripts/sample.sql") - get_engine().execute(open(test_sql).read()) - - def fin(): - get_engine().execute('DROP TABLE $scratch.sqlalchemy_tests') - - request.addfinalizer(fin) diff --git a/test/pytest.ini b/test/pytest.ini deleted file mode 100644 index f5276a2..0000000 --- a/test/pytest.ini +++ /dev/null @@ -1,5 +0,0 @@ -[pytest] -log_cli = 1 -log_cli_level = INFO -log_cli_format = %(asctime)s 
[%(levelname)8s] %(message)s (%(filename)s:%(lineno)s) -log_cli_date_format=%Y-%m-%d %H:%M:%S diff --git a/test/test_dremio.py b/test/test_dremio.py deleted file mode 100644 index f19783f..0000000 --- a/test/test_dremio.py +++ /dev/null @@ -1,57 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - -import logging -import pytest -from sqlalchemy import MetaData -from sqlalchemy.testing.schema import Table - -LOGGER = logging.getLogger(__name__) - -from . import conftest - -FULL_ROWS = """[(1, 81297389127389213, Decimal('9112.229'), 9192.921875, 9292.17272, 'ZZZZZZZZZZZZ', b'AAAAAAAAA', datetime.datetime(2020, 4, 5, 15, 8, 39, 574000), datetime.date(2020, 4, 5) -, datetime.time(12, 19, 1), True), (2, 812123489127389213, Decimal('6782.229'), 2234193.0, 9122922.17272, 'BBBBBBBBB', b'CCCCCCCCCCCC', datetime.datetime(2020, 4, 5, 15, 8, 39, 574000), datetime.date(2022, -4, 5), datetime.time(10, 19, 1), False)]""" - - -@pytest.fixture(scope='session') -def engine(): - conftest.get_engine() - return engine - - -def test_connect_args(): - """ - Tests connect string - """ - engine = conftest.get_engine() - try: - results = engine.execute('select version from sys.version').fetchone() - assert results is not None - finally: - engine.dispose() - - -def test_simple_sql(): - result = conftest.get_engine().execute('show databases') - rows = [row for row in result] - assert len(rows) >= 0, 'show database results' - - -def test_row_count(engine): - rows = conftest.get_engine().execute('SELECT * FROM $scratch.sqlalchemy_tests').fetchall() - assert len(rows) is 2 - -def test_has_table_True(): - assert conftest.get_engine().has_table("version", schema = "sys") - -def test_has_table_True2(): - assert conftest.get_engine().has_table("version") - -def test_has_table_False(): - assert not conftest.get_engine().has_table("does_not_exist", schema = "sys") - -def test_has_table_False2(): - assert not conftest.get_engine().has_table("does_not_exist") - diff --git 
a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..456b815
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,47 @@
+import os
+from pathlib import Path
+from typing import Iterator
+
+import pytest
+import sqlalchemy as sa
+import sqlalchemy.dialects
+from pytest import FixtureRequest
+
+
+@pytest.fixture(scope="session")
+def db_url() -> str:
+    """Return DREMIO_CONNECTION_URL from the environment, or a local Flight URL if it is unset."""
+    default_url = "dremio+flight://dremio:dremio123@localhost:32010?UseEncryption=false"
+    return os.environ.get("DREMIO_CONNECTION_URL", default_url)
+
+
+@pytest.fixture(scope="session")
+def engine(db_url: str) -> sa.Engine:
+    """Register the Flight dialect class under the name "dremio" and create an engine for db_url."""
+    sqlalchemy.dialects.registry.register("dremio", "sqlalchemy_dremio.flight", "DremioDialect_flight")
+    return sa.create_engine(db_url)
+
+
+@pytest.fixture(scope="session")
+def conn(engine: sa.Engine) -> Iterator[sa.Connection]:
+    """
+    Yield one connection shared by the whole session, then dispose of the engine
+    """
+    with engine.connect() as conn:
+        yield conn
+    engine.dispose()
+
+
+@pytest.fixture(scope='session', autouse=True)
+def db(request: FixtureRequest, conn: sa.Connection) -> sa.Connection:
+    """Execute scripts/sample.sql once per session and drop $scratch.sqlalchemy_tests at teardown."""
+    # Resolve against this file so the suite does not depend on pytest's working directory.
+    test_sql = Path(__file__).parent / "scripts" / "sample.sql"
+    sql = test_sql.read_text()
+    conn.execute(sa.text(sql))
+
+    def fin():
+        conn.execute(sa.text('DROP TABLE $scratch.sqlalchemy_tests'))
+
+    request.addfinalizer(fin)
+    return conn
diff --git a/scripts/sample.sql b/tests/scripts/sample.sql
similarity index 100%
rename from scripts/sample.sql
rename to tests/scripts/sample.sql
diff --git a/tests/test_dremio.py b/tests/test_dremio.py
new file mode 100644
index 0000000..6d87f2d
--- /dev/null
+++ b/tests/test_dremio.py
@@ -0,0 +1,112 @@
+import datetime
+from typing import Any, Optional
+
+import pytest
+import sqlalchemy as sa
+from sqlalchemy import Connection
+
+@pytest.fixture
+def test_table() -> sa.Table:
+    """Describe $scratch.sqlalchemy_tests as a SQLAlchemy Table, one column per tested type."""
+    meta = sa.MetaData(schema="$scratch")
+    return sa.Table("sqlalchemy_tests", meta,
+        sa.Column("int_col", sa.Integer),
+        sa.Column("bigint_col", sa.BIGINT),
+        sa.Column("decimal_col", sa.DECIMAL),
+        sa.Column("float_col", sa.FLOAT),
+        sa.Column("double_col", sa.DOUBLE),
+        sa.Column("varchar_col", sa.VARCHAR),
+        sa.Column("binary_col", sa.VARBINARY),
+        sa.Column("timestamp_col", sa.TIMESTAMP),
+        sa.Column("date_col", sa.DATE),
+        sa.Column("time_col", sa.TIME),
+        sa.Column("bool_col", sa.BOOLEAN),
+    )
+
+@pytest.mark.parametrize("param,column_name,expected", [
+    (1, "int_col", 1),
+    ("ZZZZZZZZZZZZ", "varchar_col", 1)
+])
+def test_can_use_where_statement(test_table: sa.Table, db: Connection, column_name: str, param: Any, expected: Any):
+    """A WHERE clause comparing column_name to param selects the row whose int_col equals expected."""
+    sql = sa.select(test_table.c.int_col).where(getattr(test_table.c, column_name) == param)
+    result = db.execute(sql).scalar_one()
+    assert result == expected
+
+
+@pytest.mark.parametrize("subquery_type", ["cte", "subquery"])
+def test_cte_works_correctly(test_table: sa.Table, db: Connection, subquery_type: str):
+    """Joining the table to a CTE or subquery built from itself returns the matching row."""
+    sql = sa.select(test_table.c.int_col.label('my_id')).where(test_table.c.int_col == 1)
+    if subquery_type == "cte":
+        sql = sql.cte("my_cte")
+    elif subquery_type == "subquery":
+        sql = sql.subquery("my_subquery")
+    else:
+        raise ValueError("subquery_type must be 'cte' or 'subquery'")
+    joins = sql.join(test_table, onclause=test_table.c.int_col == sql.c.my_id)
+    sql = sa.select(sql.c.my_id, test_table.c.int_col).select_from(joins)
+
+    results = db.execute(sql).fetchone()
+    assert results == (1, 1)
+
+
+def test_can_select_from_dremio(test_table: sa.Table, db: Connection):
+    """Selecting every column of the table through Core returns two rows."""
+    sql = sa.select(test_table)
+    results = db.execute(sql).fetchall()
+    assert len(results) == 2
+
+
+@pytest.mark.xfail(reason="Need to implement Dremio type hierarchy to handle Binary encoding")
+def test_can_insert_into_dremio(test_table: sa.Table, db: Connection):
+    """Insert one row through Core and read it back by int_col (marked xfail above)."""
+    sql = sa.insert(test_table).returning()
+    results = db.execute(sql, {"int_col": 3,
+                               "bigint_col": 81297389127389214,
+                               "decimal_col": 5334.532,
+                               "float_col": 9234.929,
+                               "double_col": 1234.1234,
+                               "varchar_col": 'xxxxxxx',
+                               "binary_col": 'yyyyy'.encode('utf-8'),
+                               "timestamp_col": datetime.datetime.now().timestamp(),
+                               "date_col": datetime.date.today(),
+                               "time_col": datetime.time(12, 19, 1),
+                               "bool_col": True,
+                               })
+    assert results.rowcount == 1
+    sql = sa.select(test_table).where(test_table.c.int_col == 3)
+    assert db.execute(sql).fetchone().int_col == 3
+
+
+def test_connect_args(conn: Connection):
+    """
+    Check that the session connection can fetch a row from sys.version
+    """
+
+    results = conn.execute(sa.text('select version from sys.version')).fetchone()
+    assert results is not None
+
+
+def test_simple_sql(db: Connection):
+    """'show databases' executes and its result set can be consumed."""
+    result = db.execute(sa.text('show databases'))
+    rows = list(result)
+    assert len(rows) >= 0, 'show database results'
+
+
+def test_row_count(db: Connection):
+    """A raw SELECT over $scratch.sqlalchemy_tests returns two rows."""
+    rows = db.execute(sa.text('SELECT * FROM $scratch.sqlalchemy_tests')).fetchall()
+    assert len(rows) == 2  # compare by value; `is` against an int literal tests identity
+
+
+@pytest.mark.parametrize('table_name, schema, exists', [
+    ('version', 'sys', True),
+    ('version', None, True),
+    ("does_not_exist", "sys", False),
+    ("does_not_exist", None, False),
+])
+def test_has_table(engine: sa.Engine, db: Connection, table_name: str, schema: Optional[str], exists: bool):
+    """The dialect's has_table reports table existence both with and without an explicit schema."""
+    assert engine.dialect.has_table(db, table_name, schema=schema) is exists