Skip to content

Commit

Permalink
Rails 5 support
Browse files Browse the repository at this point in the history
  • Loading branch information
kddnewton committed Jan 24, 2017
1 parent 0d8eefe commit 92eb12c
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 87 deletions.
2 changes: 1 addition & 1 deletion Gemfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ source 'https://rubygems.org'

gemspec

gem 'activerecord', '4.2.7.1'
gem 'activerecord', '5.0.1'
gem 'pry', '0.10.4'
9 changes: 4 additions & 5 deletions lib/active_record/connection_adapters/odbc_adapter.rb
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def initialize(connection, logger, dbms)
super(connection, logger)
@connection = connection
@dbms = dbms
@visitor = self.class::BindSubstitution.new(self)
end

# Returns the human-readable name of the adapter. Use mixed case - one
Expand Down Expand Up @@ -126,6 +125,10 @@ def disconnect!
@connection.disconnect if @connection.connected?
end

def new_column(name, default, sql_type_metadata, null, table_name, default_function = nil, collation = nil, native_type = nil)
::ODBCAdapter::Column.new(name, default, sql_type_metadata, null, table_name, default_function, collation, native_type)
end

protected

def initialize_type_map(map)
Expand Down Expand Up @@ -168,10 +171,6 @@ def translate_exception(exception, message)
end
end

def new_column(name, default, cast_type, sql_type = nil, null = true, native_type = nil, scale = nil, limit = nil)
::ODBCAdapter::Column.new(name, default, cast_type, sql_type, null, native_type, scale, limit)
end

private

def alias_type(map, new_type, old_type)
Expand Down
8 changes: 4 additions & 4 deletions lib/odbc_adapter/adapters/mysql_odbc_adapter.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ module Adapters
# Overrides specific to MySQL. Mostly taken from
# ActiveRecord::ConnectionAdapters::MySQLAdapter
class MySQLODBCAdapter < ActiveRecord::ConnectionAdapters::ODBCAdapter
class BindSubstitution < Arel::Visitors::MySQL
include Arel::Visitors::BindVisitor
end

PRIMARY_KEY = 'INT(11) NOT NULL AUTO_INCREMENT PRIMARY KEY'

def arel_visitor
Arel::Visitors::MySQL.new(self)
end

def truncate(table_name, name = nil)
execute("TRUNCATE TABLE #{quote_table_name(table_name)}", name)
end
Expand Down
8 changes: 4 additions & 4 deletions lib/odbc_adapter/adapters/postgresql_odbc_adapter.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ module Adapters
# Overrides specific to PostgreSQL. Mostly taken from
# ActiveRecord::ConnectionAdapters::PostgreSQLAdapter
class PostgreSQLODBCAdapter < ActiveRecord::ConnectionAdapters::ODBCAdapter
class BindSubstitution < Arel::Visitors::PostgreSQL
include Arel::Visitors::BindVisitor
end

PRIMARY_KEY = 'SERIAL PRIMARY KEY'

def arel_visitor
Arel::Visitors::PostgreSQL.new(self)
end

# Filter for ODBCAdapter#tables
# Omits table from #tables if table_filter returns true
def table_filter(schema_name, table_type)
Expand Down
19 changes: 2 additions & 17 deletions lib/odbc_adapter/column.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,9 @@ module ODBCAdapter
class Column < ActiveRecord::ConnectionAdapters::Column
attr_reader :native_type

def initialize(name, default, cast_type, sql_type, null, native_type, scale, limit)
@name = name
@default = default
@cast_type = cast_type
@sql_type = sql_type
@null = null
def initialize(name, default, sql_type_metadata = nil, null = true, table_name = nil, default_function = nil, collation = nil, native_type = nil)
super(name, default, sql_type_metadata, null, table_name, default_function, collation)
@native_type = native_type

if [ODBC::SQL_DECIMAL, ODBC::SQL_NUMERIC].include?(sql_type)
set_numeric_params(scale, limit)
end
end

private

def set_numeric_params(scale, limit)
@cast_type.instance_variable_set(:@scale, scale || 0)
@cast_type.instance_variable_set(:@precision, limit)
end
end
end
33 changes: 8 additions & 25 deletions lib/odbc_adapter/database_statements.rb
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,23 @@ def select_rows(sql, name = nil)

# Executes the SQL statement in the context of this connection.
# Returns the number of rows affected.
# TODO: Currently ignoring binds until we can get prepared statements working.
def execute(sql, name = nil, binds = [])
log(sql, name) do
@connection.do(sql)
prepared_binds =
prepare_binds_for_database(binds).map { |bind| _type_cast(bind) }
@connection.do(sql, *prepared_binds)
end
end

# Executes +sql+ statement in the context of this connection using
# +binds+ as the bind substitutes. +name+ is logged along with
# the executed +sql+ statement.
def exec_query(sql, name = 'SQL', binds = [])
def exec_query(sql, name = 'SQL', binds = [], prepare: false)
log(sql, name) do
stmt = @connection.run(sql)
prepared_binds =
prepare_binds_for_database(binds).map { |bind| _type_cast(bind) }

stmt = @connection.run(sql, *prepared_binds)
columns = stmt.columns
values = stmt.to_a
stmt.drop
Expand Down Expand Up @@ -81,33 +85,12 @@ def default_sequence_name(table, _column)
"#{table}_seq"
end

protected

# Returns the last auto-generated ID from the affected table.
def insert_sql(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil)
begin
stmt = log(sql, name) { @connection.run(sql) }
table = extract_table_ref_from_insert_sql(sql)

seq = sequence_name || default_sequence_name(table, pk)
res = id_value || last_insert_id(table, seq, stmt)
ensure
stmt.drop unless stmt.nil?
end
res
end

private

def dbms_type_cast(columns, values)
values
end

def extract_table_ref_from_insert_sql(sql)
sql[/into\s+([^\(]*).*values\s*\(/i]
$1.strip if $1
end

# Assume received identifier is in DBMS's data dictionary case.
def format_case(identifier)
case dbms.field_for(ODBC::SQL_IDENTIFIER_CASE)
Expand Down
73 changes: 43 additions & 30 deletions lib/odbc_adapter/schema_statements.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,6 @@ def native_database_types
@native_database_types ||= ColumnMetadata.new(self).native_database_types
end

# Ensure it's shorter than the maximum identifier length for the current dbms
def index_name(table_name, options)
maximum = dbms.field_for(ODBC::SQL_MAX_IDENTIFIER_LEN) || 255
super(table_name, options)[0...maximum]
end

def current_database
dbms.field_for(ODBC::SQL_DATABASE_NAME).strip
end

# Returns an array of table names, for database tables visible on the
# current connection.
def tables(_name = nil)
Expand All @@ -31,26 +21,9 @@ def tables(_name = nil)
end
end

# Returns an array of Column objects for the table specified by +table_name+.
def columns(table_name, name = nil)
stmt = @connection.columns(native_case(table_name.to_s))
result = stmt.fetch_all || []
stmt.drop

result.each_with_object([]) do |col, cols|
col_name = col[3] # SQLColumns: COLUMN_NAME
col_default = col[12] # SQLColumns: COLUMN_DEF
col_sql_type = col[4] # SQLColumns: DATA_TYPE
col_native_type = col[5] # SQLColumns: TYPE_NAME
col_limit = col[6] # SQLColumns: COLUMN_SIZE
col_scale = col[8] # SQLColumns: DECIMAL_DIGITS

# SQLColumns: IS_NULLABLE, SQLColumns: NULLABLE
col_nullable = nullability(col_name, col[17], col[10])

cast_type = lookup_cast_type(col_sql_type)
cols << new_column(format_case(col_name), col_default, cast_type, col_sql_type, col_nullable, col_native_type, col_scale, col_limit)
end
# Returns an array of view names defined in the database.
def views
[]
end

# Returns an array of indexes for the given table.
Expand Down Expand Up @@ -83,12 +56,52 @@ def indexes(table_name, name = nil)
end
end

# Returns an array of Column objects for the table specified by
# +table_name+.
def columns(table_name, name = nil)
stmt = @connection.columns(native_case(table_name.to_s))
result = stmt.fetch_all || []
stmt.drop

result.each_with_object([]) do |col, cols|
col_name = col[3] # SQLColumns: COLUMN_NAME
col_default = col[12] # SQLColumns: COLUMN_DEF
col_sql_type = col[4] # SQLColumns: DATA_TYPE
col_native_type = col[5] # SQLColumns: TYPE_NAME
col_limit = col[6] # SQLColumns: COLUMN_SIZE
col_scale = col[8] # SQLColumns: DECIMAL_DIGITS

# SQLColumns: IS_NULLABLE, SQLColumns: NULLABLE
col_nullable = nullability(col_name, col[17], col[10])

args = { sql_type: col_native_type, type: col_sql_type, limit: col_limit }
if [ODBC::SQL_DECIMAL, ODBC::SQL_NUMERIC].include?(col_sql_type)
args[:scale] = col_scale || 0
args[:precision] = col_limit
end
sql_type_metadata = ActiveRecord::ConnectionAdapters::SqlTypeMetadata.new(**args)

cols << new_column(format_case(col_name), col_default, sql_type_metadata, col_nullable, table_name)
end
end

# Returns just a table's primary key
def primary_key(table_name)
stmt = @connection.primary_keys(native_case(table_name.to_s))
result = stmt.fetch_all || []
stmt.drop unless stmt.nil?
result[0] && result[0][3]
end

# Ensure it's shorter than the maximum identifier length for the current
# dbms
def index_name(table_name, options)
maximum = dbms.field_for(ODBC::SQL_MAX_IDENTIFIER_LEN) || 255
super(table_name, options)[0...maximum]
end

def current_database
dbms.field_for(ODBC::SQL_DATABASE_NAME).strip
end
end
end
2 changes: 1 addition & 1 deletion test/metadata_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

class MetadataTest < Minitest::Test
def test_tables
assert_equal %w[todos users], User.connection.tables.sort
assert_equal %w[ar_internal_metadata todos users], User.connection.tables.sort
end

def test_column_names
Expand Down

0 comments on commit 92eb12c

Please sign in to comment.