Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored python code to use modern code styles #48

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,8 @@ default = dremio://dremio:dremio123@localhost:31010/dremio
tag_build =
tag_date = 0

[tool:pytest]
log_cli = 1
log_cli_level = INFO
log_cli_format = %(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)
log_cli_date_format=%Y-%m-%d %H:%M:%S
119 changes: 57 additions & 62 deletions sqlalchemy_dremio/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from sqlalchemy import schema, types, pool
from typing import Optional

from sqlalchemy import schema, types, pool, Table, Connection, Column
from sqlalchemy.engine import default, reflection
from sqlalchemy.engine.interfaces import DBAPICursor, _DBAPISingleExecuteParams, ExecutionContext
from sqlalchemy.sql import compiler

_dialect_name = "dremio"
Expand Down Expand Up @@ -34,7 +37,6 @@
'smallint': types.SMALLINT,
'CHARACTER VARYING': types.VARCHAR,
'ANY': types.VARCHAR,

'ARRAY': types.ARRAY,
'ROW': types.JSON,
'BINARY VARYING': types.LargeBinary,
Expand All @@ -46,45 +48,40 @@ class DremioExecutionContext(default.DefaultExecutionContext):


class DremioCompiler(compiler.SQLCompiler):
def visit_char_length_func(self, fn, **kw):
return 'length{}'.format(self.function_argspec(fn, **kw))

def visit_table(self, table, asfrom=False, **kwargs):
def visit_char_length_func(self, fn, **kw) -> str:
return f'length{self.function_argspec(fn, **kw)}'

def visit_table(self, table: Table, asfrom: bool = False, **kwargs) -> str:
if asfrom:
if table.schema != None and table.schema != "":
fixed_schema = ".".join(["\"" + i.replace('"', '') + "\"" for i in table.schema.split(".")])
fixed_table = fixed_schema + ".\"" + table.name.replace("\"", "") + "\""
if table.schema is not None and table.schema != "":
fixed_schema = ".".join(['"' + i.replace('"', '') + '"' for i in table.schema.split(".")])
fixed_table = fixed_schema + '"."' + table.name.replace('"', "") + '"'
else:
# don't change anything. expect a fully and properly qualified path if no schema is passed.
fixed_table = table.name
# fixed_table = "\"" + table.name.replace("\"", "") + "\""
return fixed_table
else:
return ""

def visit_tablesample(self, tablesample, asfrom=False, **kw):
def visit_tablesample(self, tablesample, asfrom: bool = False, **kw) -> None:
# TODO: This is currently a noop
print(tablesample)


class DremioDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kwargs):
def get_column_specification(self, column: Column, **kwargs) -> str:
colspec = self.preparer.format_column(column)
colspec += " " + self.dialect.type_compiler.process(column.type)
if column is column.table._autoincrement_column and \
True and \
(
column.default is None or \
isinstance(column.default, schema.Sequence)
):
if column is column.table._autoincrement_column and (
column.default is None or isinstance(column.default, schema.Sequence)
):
colspec += " IDENTITY"
if isinstance(column.default, schema.Sequence) and \
column.default.start > 0:
if isinstance(column.default, schema.Sequence) and column.default.start > 0:
colspec += " " + str(column.default.start)
else:
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
colspec += f" DEFAULT {default}"

if not column.nullable:
colspec += " NOT NULL"
Expand Down Expand Up @@ -141,9 +138,8 @@ class DremioIdentifierPreparer(compiler.IdentifierPreparer):
dremio_unique = dremio_reserved - reserved_words
reserved_words.update(list(dremio_unique))

def __init__(self, dialect):
super(DremioIdentifierPreparer, self). \
__init__(dialect, initial_quote='"', final_quote='"')
def __init__(self, dialect: default.Dialect):
super().__init__(dialect, initial_quote='"', final_quote='"')


class DremioDialect(default.DefaultDialect):
Expand All @@ -169,76 +165,77 @@ def connect(self, *cargs, **cparams):
if 'autocommit' not in engine_params:
cparams['autocommit'] = 1

return self.dbapi.connect(*cargs, **cparams)
return self.loaded_dbapi.connect(*cargs, **cparams)

def last_inserted_ids(self):
return self.context.last_inserted_ids

def get_indexes(self, connection, table_name, schema, **kw):
def get_indexes(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kw):
return []

def get_pk_constraint(self, connection, table_name, schema=None, **kw):
def get_pk_constraint(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kw):
return []

def get_foreign_keys(self, connection, table_name, schema=None, **kw):
def get_foreign_keys(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kw):
return []

def get_columns(self, connection, table_name, schema, **kw):
sql = "DESCRIBE \"{0}\"".format(table_name)
if schema != None and schema != "":
sql = "DESCRIBE \"{0}\".\"{1}\"".format(schema, table_name)
cursor = connection.execute(sql)
result = []
for col in cursor:
cname = col[0]
ctype = _type_map[col[1]]
column = {
"name": cname,
"type": ctype,
"default": None,
"comment": None,
"nullable": True
}
result.append(column)
return (result)
def get_columns(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kw) -> list[dict]:
sql = f'DESCRIBE "{table_name}"'
if schema is not None and schema != "":
sql = f'DESCRIBE "{schema}"."{table_name}"'
cursor = connection.exec_driver_sql(sql)

result = [{
"name": col[0],
"type": _type_map[col[1]],
"default": None,
"comment": None,
"nullable": True
} for col in cursor]

return result

@reflection.cache
def get_table_names(self, connection, schema, **kw):
def get_table_names(self, connection: Connection, schema: str, **kw) -> list[str]:
sql = 'SELECT TABLE_NAME FROM INFORMATION_SCHEMA."TABLES"'

# Reverting #5 as Dremio does not support parameterized queries.
if schema is not None:
sql += " WHERE TABLE_SCHEMA = '" + schema + "'"
sql += f" WHERE TABLE_SCHEMA = {schema}"

result = connection.execute(sql)
result = connection.exec_driver_sql(sql)
table_names = [r[0] for r in result]
return table_names

def get_schema_names(self, connection, schema=None, **kw):
def get_schema_names(self, connection: Connection, schema: Optional[str] = None, **kw) -> list[str]:
if len(self.filter_schema_names) > 0:
return self.filter_schema_names

result = connection.execute("SHOW SCHEMAS")
result = connection.exec_driver_sql("SHOW SCHEMAS")
schema_names = [r[0] for r in result]
return schema_names

@reflection.cache
def has_table(self, connection, table_name, schema=None, **kw):
def has_table(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kw) -> bool:
sql = 'SELECT COUNT(*) FROM INFORMATION_SCHEMA."TABLES"'
sql += " WHERE TABLE_NAME = '" + str(table_name) + "'"
sql += f" WHERE TABLE_NAME = {table_name}"
if schema is not None and schema != "":
sql += " AND TABLE_SCHEMA = '" + str(schema) + "'"
result = connection.execute(sql)
countRows = [r[0] for r in result]
return countRows[0] > 0
sql += f" AND TABLE_SCHEMA = {schema}"
result = connection.exec_driver_sql(sql)
count_rows = [r[0] for r in result]
return count_rows[0] > 0

def get_view_names(self, connection, schema=None, **kwargs):
def get_view_names(self, connection: Connection, schema: Optional[str] = None, **kwargs) -> list[str]:
return []

# Workaround since Dremio does not support parameterized stmts
# Old queries should not have used queries with parameters, since Dremio does not support it
# and these queries failed. If there is no parameter, everything should work as before.
def do_execute(self, cursor, statement, parameters, context):
def do_execute(self,
cursor: DBAPICursor,
statement: str,
parameters: Optional[_DBAPISingleExecuteParams],
context: Optional[ExecutionContext] = None) -> None:
replaced_stmt = statement
for v in parameters:
escaped_str = str(v).replace("'", "''")
Expand All @@ -247,6 +244,4 @@ def do_execute(self, cursor, statement, parameters, context):
else:
replaced_stmt = replaced_stmt.replace('?', "'" + escaped_str + "'", 1)

super(DremioDialect, self).do_execute_no_params(
cursor, replaced_stmt, context
)
super().do_execute_no_params(cursor, replaced_stmt, context)
Loading