diff --git a/Gemfile b/Gemfile index 4fbef2ba..1e6b358a 100644 --- a/Gemfile +++ b/Gemfile @@ -2,5 +2,5 @@ source 'https://rubygems.org' gemspec -gem 'activerecord', '4.2.7.1' +gem 'activerecord', '5.0.1' gem 'pry', '0.10.4' diff --git a/lib/active_record/connection_adapters/odbc_adapter.rb b/lib/active_record/connection_adapters/odbc_adapter.rb index 567ff408..5da1de82 100644 --- a/lib/active_record/connection_adapters/odbc_adapter.rb +++ b/lib/active_record/connection_adapters/odbc_adapter.rb @@ -82,7 +82,6 @@ def initialize(connection, logger, dbms) super(connection, logger) @connection = connection @dbms = dbms - @visitor = self.class::BindSubstitution.new(self) end # Returns the human-readable name of the adapter. Use mixed case - one @@ -126,6 +125,10 @@ def disconnect! @connection.disconnect if @connection.connected? end + def new_column(name, default, sql_type_metadata, null, table_name, default_function = nil, collation = nil, native_type = nil) + ::ODBCAdapter::Column.new(name, default, sql_type_metadata, null, table_name, default_function, collation, native_type) + end + protected def initialize_type_map(map) @@ -168,10 +171,6 @@ def translate_exception(exception, message) end end - def new_column(name, default, cast_type, sql_type = nil, null = true, native_type = nil, scale = nil, limit = nil) - ::ODBCAdapter::Column.new(name, default, cast_type, sql_type, null, native_type, scale, limit) - end - private def alias_type(map, new_type, old_type) diff --git a/lib/odbc_adapter/adapters/mysql_odbc_adapter.rb b/lib/odbc_adapter/adapters/mysql_odbc_adapter.rb index be144227..154c918f 100644 --- a/lib/odbc_adapter/adapters/mysql_odbc_adapter.rb +++ b/lib/odbc_adapter/adapters/mysql_odbc_adapter.rb @@ -3,12 +3,12 @@ module Adapters # Overrides specific to MySQL. Mostly taken from # ActiveRecord::ConnectionAdapters::MySQLAdapter class MySQLODBCAdapter < ActiveRecord::ConnectionAdapters::ODBCAdapter - class BindSubstitution < Arel::Visitors::MySQL - include Arel::Visitors::BindVisitor - end - PRIMARY_KEY = 'INT(11) NOT NULL AUTO_INCREMENT PRIMARY KEY' + def arel_visitor + Arel::Visitors::MySQL.new(self) + end + def truncate(table_name, name = nil) execute("TRUNCATE TABLE #{quote_table_name(table_name)}", name) end diff --git a/lib/odbc_adapter/adapters/postgresql_odbc_adapter.rb b/lib/odbc_adapter/adapters/postgresql_odbc_adapter.rb index daca1a86..23f9d8c4 100644 --- a/lib/odbc_adapter/adapters/postgresql_odbc_adapter.rb +++ b/lib/odbc_adapter/adapters/postgresql_odbc_adapter.rb @@ -3,12 +3,12 @@ module Adapters # Overrides specific to PostgreSQL. Mostly taken from # ActiveRecord::ConnectionAdapters::PostgreSQLAdapter class PostgreSQLODBCAdapter < ActiveRecord::ConnectionAdapters::ODBCAdapter - class BindSubstitution < Arel::Visitors::PostgreSQL - include Arel::Visitors::BindVisitor - end - PRIMARY_KEY = 'SERIAL PRIMARY KEY' + def arel_visitor + Arel::Visitors::PostgreSQL.new(self) + end + # Filter for ODBCAdapter#tables # Omits table from #tables if table_filter returns true def table_filter(schema_name, table_type) diff --git a/lib/odbc_adapter/column.rb b/lib/odbc_adapter/column.rb index c0d62f28..d230de24 100644 --- a/lib/odbc_adapter/column.rb +++ b/lib/odbc_adapter/column.rb @@ -2,24 +2,9 @@ module ODBCAdapter class Column < ActiveRecord::ConnectionAdapters::Column attr_reader :native_type - def initialize(name, default, cast_type, sql_type, null, native_type, scale, limit) - @name = name - @default = default - @cast_type = cast_type - @sql_type = sql_type - @null = null + def initialize(name, default, sql_type_metadata = nil, null = true, table_name = nil, default_function = nil, collation = nil, native_type = nil) + super(name, default, sql_type_metadata, null, table_name, default_function, collation) @native_type = native_type - - if [ODBC::SQL_DECIMAL, ODBC::SQL_NUMERIC].include?(sql_type) - set_numeric_params(scale, limit) - end - end - - private - - def set_numeric_params(scale, limit) - @cast_type.instance_variable_set(:@scale, scale || 0) - @cast_type.instance_variable_set(:@precision, limit) end end end diff --git a/lib/odbc_adapter/database_statements.rb b/lib/odbc_adapter/database_statements.rb index 652d2ac1..5e6e29ad 100644 --- a/lib/odbc_adapter/database_statements.rb +++ b/lib/odbc_adapter/database_statements.rb @@ -18,19 +18,23 @@ def select_rows(sql, name = nil) # Executes the SQL statement in the context of this connection. # Returns the number of rows affected. - # TODO: Currently ignoring binds until we can get prepared statements working. def execute(sql, name = nil, binds = []) log(sql, name) do - @connection.do(sql) + prepared_binds = + prepare_binds_for_database(binds).map { |bind| _type_cast(bind) } + @connection.do(sql, *prepared_binds) end end # Executes +sql+ statement in the context of this connection using # +binds+ as the bind substitutes. +name+ is logged along with # the executed +sql+ statement. - def exec_query(sql, name = 'SQL', binds = []) + def exec_query(sql, name = 'SQL', binds = [], prepare: false) log(sql, name) do - stmt = @connection.run(sql) + prepared_binds = + prepare_binds_for_database(binds).map { |bind| _type_cast(bind) } + + stmt = @connection.run(sql, *prepared_binds) columns = stmt.columns values = stmt.to_a stmt.drop @@ -81,33 +85,12 @@ def default_sequence_name(table, _column) "#{table}_seq" end - protected - - # Returns the last auto-generated ID from the affected table. - def insert_sql(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil) - begin - stmt = log(sql, name) { @connection.run(sql) } - table = extract_table_ref_from_insert_sql(sql) - - seq = sequence_name || default_sequence_name(table, pk) - res = id_value || last_insert_id(table, seq, stmt) - ensure - stmt.drop unless stmt.nil? - end - res - end - private def dbms_type_cast(columns, values) values end - def extract_table_ref_from_insert_sql(sql) - sql[/into\s+([^\(]*).*values\s*\(/i] - $1.strip if $1 - end - # Assume received identifier is in DBMS's data dictionary case. def format_case(identifier) case dbms.field_for(ODBC::SQL_IDENTIFIER_CASE) diff --git a/lib/odbc_adapter/schema_statements.rb b/lib/odbc_adapter/schema_statements.rb index 83c7f410..7c095b9d 100644 --- a/lib/odbc_adapter/schema_statements.rb +++ b/lib/odbc_adapter/schema_statements.rb @@ -7,16 +7,6 @@ def native_database_types @native_database_types ||= ColumnMetadata.new(self).native_database_types end - # Ensure it's shorter than the maximum identifier length for the current dbms - def index_name(table_name, options) - maximum = dbms.field_for(ODBC::SQL_MAX_IDENTIFIER_LEN) || 255 - super(table_name, options)[0...maximum] - end - - def current_database - dbms.field_for(ODBC::SQL_DATABASE_NAME).strip - end - # Returns an array of table names, for database tables visible on the # current connection. def tables(_name = nil) @@ -31,26 +21,9 @@ def tables(_name = nil) end end - # Returns an array of Column objects for the table specified by +table_name+. - def columns(table_name, name = nil) - stmt = @connection.columns(native_case(table_name.to_s)) - result = stmt.fetch_all || [] - stmt.drop - - result.each_with_object([]) do |col, cols| - col_name = col[3] # SQLColumns: COLUMN_NAME - col_default = col[12] # SQLColumns: COLUMN_DEF - col_sql_type = col[4] # SQLColumns: DATA_TYPE - col_native_type = col[5] # SQLColumns: TYPE_NAME - col_limit = col[6] # SQLColumns: COLUMN_SIZE - col_scale = col[8] # SQLColumns: DECIMAL_DIGITS - - # SQLColumns: IS_NULLABLE, SQLColumns: NULLABLE - col_nullable = nullability(col_name, col[17], col[10]) - - cast_type = lookup_cast_type(col_sql_type) - cols << new_column(format_case(col_name), col_default, cast_type, col_sql_type, col_nullable, col_native_type, col_scale, col_limit) - end + # Returns an array of view names defined in the database. + def views + [] end # Returns an array of indexes for the given table. @@ -83,6 +56,35 @@ def indexes(table_name, name = nil) end end + # Returns an array of Column objects for the table specified by + # +table_name+. + def columns(table_name, name = nil) + stmt = @connection.columns(native_case(table_name.to_s)) + result = stmt.fetch_all || [] + stmt.drop + + result.each_with_object([]) do |col, cols| + col_name = col[3] # SQLColumns: COLUMN_NAME + col_default = col[12] # SQLColumns: COLUMN_DEF + col_sql_type = col[4] # SQLColumns: DATA_TYPE + col_native_type = col[5] # SQLColumns: TYPE_NAME + col_limit = col[6] # SQLColumns: COLUMN_SIZE + col_scale = col[8] # SQLColumns: DECIMAL_DIGITS + + # SQLColumns: IS_NULLABLE, SQLColumns: NULLABLE + col_nullable = nullability(col_name, col[17], col[10]) + + args = { sql_type: col_native_type, type: col_sql_type, limit: col_limit } + if [ODBC::SQL_DECIMAL, ODBC::SQL_NUMERIC].include?(col_sql_type) + args[:scale] = col_scale || 0 + args[:precision] = col_limit + end + sql_type_metadata = ActiveRecord::ConnectionAdapters::SqlTypeMetadata.new(**args) + + cols << new_column(format_case(col_name), col_default, sql_type_metadata, col_nullable, table_name) + end + end + # Returns just a table's primary key def primary_key(table_name) stmt = @connection.primary_keys(native_case(table_name.to_s)) @@ -90,5 +92,16 @@ def primary_key(table_name) stmt.drop unless stmt.nil? result[0] && result[0][3] end + + # Ensure it's shorter than the maximum identifier length for the current + # dbms + def index_name(table_name, options) + maximum = dbms.field_for(ODBC::SQL_MAX_IDENTIFIER_LEN) || 255 + super(table_name, options)[0...maximum] + end + + def current_database + dbms.field_for(ODBC::SQL_DATABASE_NAME).strip + end end end diff --git a/test/metadata_test.rb b/test/metadata_test.rb index 68201a81..b93bd466 100644 --- a/test/metadata_test.rb +++ b/test/metadata_test.rb @@ -2,7 +2,7 @@ class MetadataTest < Minitest::Test def test_tables - assert_equal %w[todos users], User.connection.tables.sort + assert_equal %w[ar_internal_metadata todos users], User.connection.tables.sort end def test_column_names