From a9817045a324588db5deebb795e58d5d26b708b8 Mon Sep 17 00:00:00 2001 From: Unknown Date: Mon, 30 Mar 2020 14:24:17 +0800 Subject: [PATCH] Move some code to the statements sub package MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move some code to the statements sub package Revert change for delete Refactor new engine Fix tests Improve quote policy Names with upper-case characters on postgres need quotes Fix bug Add new quote parameter on tests Fix tests Fix quotes Improve quote policy Fix cache bug Fix map with cols Fix rows bug Improve dialect interface Return SQLs for create table Code improvement Fix find alias bug Improve insert Only replace quotes when necessary Move value2interface from session to the statements package Remove duplicated code Fix mssql issue with duplicate columns. Fix err, add postgres column ARRAY Fix table name Ignore schema when dbtype is not postgres Support count with cols Fix bug when dumping Fix panic when batch inserting an interface slice Fix master/slave bug Fix pk bug Fix dump/import bug Fix SetSchema Fix dump bug Oracle: Local Naming Method Stop using github.com/xorm/core Fix find and count bug Move processor functions into one file Fix duplicated deleted condition on FindAndCount Don't keep db on dialects Add more tests for FindAndCount Fix dump test Fix postgres schema problem Fix quote with blank Fix lint errors Move column string to a standalone method Support session ID Improve SQL generation for insert with a map Move all integration tests to a standalone sub package --- caches/cache.go | 99 + caches/encode.go | 58 + caches/leveldb.go | 94 + caches/leveldb_test.go | 39 + cache_lru.go => caches/lru.go | 20 +- cache_lru_test.go => caches/lru_test.go | 6 +- caches/manager.go | 56 + .../memory_store.go | 6 +- .../memory_store_test.go | 6 +- context_test.go | 32 - contexts/context_cache.go | 30 + convert.go | 178 +- convert/conversion.go | 12 + core/db.go | 290 ++ core/db_test.go | 684 ++++ core/error.go | 14 + core/interface.go | 22 + core/rows.go | 338 ++ core/scan.go | 66 + core/stmt.go | 228 ++ core/tx.go | 233 ++ dialects/dialect.go | 284 ++ dialects/driver.go | 57 + dialects/filter.go | 43 + dialects/filter_test.go | 21 + dialects/gen_reserved.sh | 6 + dialect_mssql.go => dialects/mssql.go | 217 +- .../mssql_test.go | 10 +- dialect_mysql.go => dialects/mysql.go | 251 +- dialect_oracle.go => dialects/oracle.go | 295 +- dialects/oracle_test.go | 34 + dialects/pg_reserved.txt | 746 +++++ dialect_postgres.go => dialects/postgres.go | 385 ++- .../postgres_test.go | 19 +- dialects/quote.go | 15 + dialect_sqlite3.go => dialects/sqlite3.go | 235 +- .../sqlite3_test.go | 2 +- dialects/table_name.go | 89 + .../table_name_test.go | 12 +- dialects/time.go | 49 + engine.go | 851 ++--- engine_cond.go | 232 -- engine_context.go | 28 - engine_context_test.go | 28 - engine_group.go | 43 +- engine_group_policy.go | 2 + engine_table.go | 108 - engine_test.go | 89 - engineplus.go | 2 +- error.go | 4 + fswatcher.go | 6 +- helpers.go | 343 -- helpers_plus.go | 11 +- helpers_test.go | 27 - helpler_time.go | 21 - cache_test.go => integrations/cache_test.go | 22 +- integrations/engine_group_test.go | 35 + integrations/engine_test.go | 141 + types.go => integrations/main_test.go | 13 +- .../processors_test.go | 102 +- rows_test.go => integrations/rows_test.go | 24 +- .../session_cols_test.go | 12 +- .../session_cond_test.go | 26 +- .../session_delete_test.go | 20 +- .../session_exist_test.go | 31 +- .../session_find_test.go | 416 ++- .../session_get_test.go | 99 +-
.../session_insert_test.go | 521 ++- .../session_iterate_test.go | 6 +- integrations/session_pk_test.go | 673 ++++ .../session_query_test.go | 30 +- .../session_raw_test.go | 4 +- .../session_schema_test.go | 53 +- .../session_stats_test.go | 40 +- .../session_test.go | 17 +- .../session_tx_test.go | 42 +- .../session_update_test.go | 669 ++-- integrations/tags_test.go | 1329 ++++++++ xorm_test.go => integrations/tests.go | 120 +- time_test.go => integrations/time_test.go | 69 +- .../types_null_test.go | 252 +- types_test.go => integrations/types_test.go | 90 +- interface.go | 47 +- internal/json/json.go | 31 + internal/statements/cache.go | 79 + internal/statements/column_map.go | 66 + .../statements/expr_param.go | 49 +- internal/statements/insert.go | 207 ++ internal/statements/pk.go | 79 + internal/statements/query.go | 534 +++ internal/statements/statement.go | 996 ++++++ .../statements/statement_args.go | 70 +- .../statements/statement_test.go | 168 +- internal/statements/update.go | 295 ++ internal/statements/values.go | 154 + internal/utils/name.go | 13 + internal/utils/reflect.go | 13 + internal/utils/slice.go | 22 + internal/utils/sql.go | 19 + internal/utils/strings.go | 30 + uuid.go => internal/utils/uuid.go | 2 +- internal/utils/zero.go | 145 + internal/utils/zero_test.go | 73 + logger.go => log/logger.go | 106 +- log/logger_context.go | 121 + syslogger.go => log/syslogger.go | 14 +- migrate/migrate.go | 4 +- migrate/migrate_test.go | 2 +- names/mapper.go | 258 ++ names/mapper_test.go | 49 + names/table_name.go | 56 + names/table_name_test.go | 140 + processors.go | 66 + rows.go | 47 +- schemas/column.go | 117 + schemas/index.go | 72 + schemas/pk.go | 41 + schemas/pk_test.go | 36 + schemas/quote.go | 240 ++ schemas/quote_test.go | 181 + schemas/table.go | 146 + schemas/table_test.go | 111 + schemas/type.go | 336 ++ session.go | 340 +- session_cols.go | 16 +- session_cond.go | 17 +- session_context.go | 23 - session_context_test.go | 36 - session_convert.go | 142 +- session_delete.go | 61 +- session_exist.go | 81 +- session_find.go | 187 +- session_get.go | 57 +- session_insert.go | 549 +--- session_iterate.go | 14 +- session_pk_test.go | 1199 ------- session_plus.go | 122 +- session_query.go | 116 +- session_raw.go | 54 +- session_schema.go | 172 +- session_stats.go | 31 +- session_sum_test.go | 152 - session_tx.go | 110 +- session_tx_plus.go | 19 +- session_update.go | 203 +- sql_executor.go | 30 +- statement.go | 1291 -------- statement_columnmap.go | 35 - statement_quote.go | 19 - table_name.go | 31 - table_name_test.go | 73 - tag_cache_test.go | 35 - tag_extends_test.go | 608 ---- tag_id_test.go | 85 - tag_test.go | 601 ---- tag_version_test.go | 242 -- tags/parser.go | 307 ++ tags/parser_test.go | 44 + tag.go => tags/tag.go | 101 +- tags/tag_test.go | 30 + transaction.go | 26 - transancation_test.go | 52 - .../github.com/syndtr/goleveldb/.travis.yml | 12 + vendor/github.com/syndtr/goleveldb/LICENSE | 24 + vendor/github.com/syndtr/goleveldb/README.md | 107 + vendor/github.com/syndtr/goleveldb/go.mod | 9 + vendor/github.com/syndtr/goleveldb/go.sum | 25 + .../syndtr/goleveldb/leveldb/batch.go | 354 ++ .../syndtr/goleveldb/leveldb/batch_test.go | 147 + .../syndtr/goleveldb/leveldb/bench_test.go | 507 +++ .../goleveldb/leveldb/cache/bench_test.go | 29 + .../syndtr/goleveldb/leveldb/cache/cache.go | 704 ++++ .../goleveldb/leveldb/cache/cache_test.go | 563 ++++ .../syndtr/goleveldb/leveldb/cache/lru.go | 195 ++ .../syndtr/goleveldb/leveldb/comparer.go | 67 + 
.../leveldb/comparer/bytes_comparer.go | 51 + .../goleveldb/leveldb/comparer/comparer.go | 57 + .../syndtr/goleveldb/leveldb/corrupt_test.go | 498 +++ .../github.com/syndtr/goleveldb/leveldb/db.go | 1205 +++++++ .../syndtr/goleveldb/leveldb/db_compaction.go | 865 +++++ .../syndtr/goleveldb/leveldb/db_iter.go | 369 +++ .../syndtr/goleveldb/leveldb/db_snapshot.go | 187 ++ .../syndtr/goleveldb/leveldb/db_state.go | 239 ++ .../syndtr/goleveldb/leveldb/db_test.go | 2926 +++++++++++++++++ .../goleveldb/leveldb/db_transaction.go | 335 ++ .../syndtr/goleveldb/leveldb/db_util.go | 102 + .../syndtr/goleveldb/leveldb/db_write.go | 464 +++ .../syndtr/goleveldb/leveldb/doc.go | 92 + .../syndtr/goleveldb/leveldb/errors.go | 20 + .../syndtr/goleveldb/leveldb/errors/errors.go | 78 + .../syndtr/goleveldb/leveldb/external_test.go | 117 + .../syndtr/goleveldb/leveldb/filter.go | 31 + .../syndtr/goleveldb/leveldb/filter/bloom.go | 116 + .../goleveldb/leveldb/filter/bloom_test.go | 142 + .../syndtr/goleveldb/leveldb/filter/filter.go | 60 + .../goleveldb/leveldb/iterator/array_iter.go | 184 ++ .../leveldb/iterator/array_iter_test.go | 30 + .../leveldb/iterator/indexed_iter.go | 242 ++ .../leveldb/iterator/indexed_iter_test.go | 83 + .../syndtr/goleveldb/leveldb/iterator/iter.go | 132 + .../leveldb/iterator/iter_suite_test.go | 11 + .../goleveldb/leveldb/iterator/merged_iter.go | 304 ++ .../leveldb/iterator/merged_iter_test.go | 60 + .../goleveldb/leveldb/journal/journal.go | 524 +++ .../goleveldb/leveldb/journal/journal_test.go | 818 +++++ .../syndtr/goleveldb/leveldb/key.go | 143 + .../syndtr/goleveldb/leveldb/key_test.go | 133 + .../goleveldb/leveldb/leveldb_suite_test.go | 11 + .../goleveldb/leveldb/memdb/bench_test.go | 75 + .../syndtr/goleveldb/leveldb/memdb/memdb.go | 479 +++ .../leveldb/memdb/memdb_suite_test.go | 11 + .../goleveldb/leveldb/memdb/memdb_test.go | 135 + .../syndtr/goleveldb/leveldb/opt/options.go | 716 ++++ .../syndtr/goleveldb/leveldb/options.go | 107 + .../syndtr/goleveldb/leveldb/session.go | 239 ++ .../goleveldb/leveldb/session_compaction.go | 326 ++ .../goleveldb/leveldb/session_record.go | 323 ++ .../goleveldb/leveldb/session_record_test.go | 62 + .../syndtr/goleveldb/leveldb/session_util.go | 483 +++ .../syndtr/goleveldb/leveldb/storage.go | 63 + .../goleveldb/leveldb/storage/file_storage.go | 671 ++++ .../leveldb/storage/file_storage_nacl.go | 34 + .../leveldb/storage/file_storage_plan9.go | 63 + .../leveldb/storage/file_storage_solaris.go | 81 + .../leveldb/storage/file_storage_test.go | 400 +++ .../leveldb/storage/file_storage_unix.go | 98 + .../leveldb/storage/file_storage_windows.go | 78 + .../goleveldb/leveldb/storage/mem_storage.go | 222 ++ .../leveldb/storage/mem_storage_test.go | 117 + .../goleveldb/leveldb/storage/storage.go | 187 ++ .../syndtr/goleveldb/leveldb/table.go | 600 ++++ .../goleveldb/leveldb/table/block_test.go | 139 + .../syndtr/goleveldb/leveldb/table/reader.go | 1139 +++++++ .../syndtr/goleveldb/leveldb/table/table.go | 177 + .../leveldb/table/table_suite_test.go | 11 + .../goleveldb/leveldb/table/table_test.go | 123 + .../syndtr/goleveldb/leveldb/table/writer.go | 375 +++ .../syndtr/goleveldb/leveldb/table_test.go | 159 + .../syndtr/goleveldb/leveldb/testutil/db.go | 222 ++ .../goleveldb/leveldb/testutil/ginkgo.go | 21 + .../syndtr/goleveldb/leveldb/testutil/iter.go | 327 ++ .../syndtr/goleveldb/leveldb/testutil/kv.go | 352 ++ .../goleveldb/leveldb/testutil/kvtest.go | 212 ++ .../goleveldb/leveldb/testutil/storage.go | 696 ++++ 
.../syndtr/goleveldb/leveldb/testutil/util.go | 171 + .../syndtr/goleveldb/leveldb/testutil_test.go | 91 + .../syndtr/goleveldb/leveldb/util.go | 98 + .../syndtr/goleveldb/leveldb/util/buffer.go | 293 ++ .../goleveldb/leveldb/util/buffer_pool.go | 239 ++ .../goleveldb/leveldb/util/buffer_test.go | 369 +++ .../syndtr/goleveldb/leveldb/util/crc32.go | 30 + .../syndtr/goleveldb/leveldb/util/hash.go | 48 + .../goleveldb/leveldb/util/hash_test.go | 46 + .../syndtr/goleveldb/leveldb/util/range.go | 32 + .../syndtr/goleveldb/leveldb/util/util.go | 73 + .../syndtr/goleveldb/leveldb/version.go | 573 ++++ .../syndtr/goleveldb/leveldb/version_test.go | 431 +++ .../goleveldb/manualtest/dbstress/key.go | 137 + .../goleveldb/manualtest/dbstress/main.go | 630 ++++ .../goleveldb/manualtest/filelock/main.go | 85 + xorm.go | 89 +- 261 files changed, 40795 insertions(+), 10013 deletions(-) create mode 100644 caches/cache.go create mode 100644 caches/encode.go create mode 100644 caches/leveldb.go create mode 100644 caches/leveldb_test.go rename cache_lru.go => caches/lru.go (93%) rename cache_lru_test.go => caches/lru_test.go (93%) create mode 100644 caches/manager.go rename cache_memory_store.go => caches/memory_store.go (92%) rename cache_memory_store_test.go => caches/memory_store_test.go (91%) delete mode 100644 context_test.go create mode 100644 contexts/context_cache.go create mode 100644 convert/conversion.go create mode 100644 core/db.go create mode 100644 core/db_test.go create mode 100644 core/error.go create mode 100644 core/interface.go create mode 100644 core/rows.go create mode 100644 core/scan.go create mode 100644 core/stmt.go create mode 100644 core/tx.go create mode 100644 dialects/dialect.go create mode 100644 dialects/driver.go create mode 100644 dialects/filter.go create mode 100644 dialects/filter_test.go create mode 100755 dialects/gen_reserved.sh rename dialect_mssql.go => dialects/mssql.go (76%) rename dialect_mssql_test.go => dialects/mssql_test.go (83%) rename dialect_mysql.go => dialects/mysql.go (76%) rename dialect_oracle.go => dialects/oracle.go (79%) create mode 100644 dialects/oracle_test.go create mode 100644 dialects/pg_reserved.txt rename dialect_postgres.go => dialects/postgres.go (83%) rename dialect_postgres_test.go => dialects/postgres_test.go (91%) create mode 100644 dialects/quote.go rename dialect_sqlite3.go => dialects/sqlite3.go (67%) rename dialect_sqlite3_test.go => dialects/sqlite3_test.go (97%) create mode 100644 dialects/table_name.go rename engine_table_test.go => dialects/table_name_test.go (59%) create mode 100644 dialects/time.go delete mode 100644 engine_cond.go delete mode 100644 engine_context.go delete mode 100644 engine_context_test.go delete mode 100644 engine_table.go delete mode 100644 engine_test.go delete mode 100644 helpers.go delete mode 100644 helpers_test.go delete mode 100644 helpler_time.go rename cache_test.go => integrations/cache_test.go (88%) create mode 100644 integrations/engine_group_test.go create mode 100644 integrations/engine_test.go rename types.go => integrations/main_test.go (55%) rename processors_test.go => integrations/processors_test.go (90%) rename rows_test.go => integrations/rows_test.go (87%) rename session_cols_test.go => integrations/session_cols_test.go (93%) rename session_cond_test.go => integrations/session_cond_test.go (88%) rename session_delete_test.go => integrations/session_delete_test.go (93%) rename session_exist_test.go => integrations/session_exist_test.go (86%) rename session_find_test.go => 
integrations/session_find_test.go (66%) rename session_get_test.go => integrations/session_get_test.go (88%) rename session_insert_test.go => integrations/session_insert_test.go (72%) rename session_iterate_test.go => integrations/session_iterate_test.go (96%) create mode 100644 integrations/session_pk_test.go rename session_query_test.go => integrations/session_query_test.go (91%) rename session_raw_test.go => integrations/session_raw_test.go (94%) rename session_schema_test.go => integrations/session_schema_test.go (89%) rename session_stats_test.go => integrations/session_stats_test.go (87%) rename session_test.go => integrations/session_test.go (70%) rename session_tx_test.go => integrations/session_tx_test.go (84%) rename session_update_test.go => integrations/session_update_test.go (76%) create mode 100644 integrations/tags_test.go rename xorm_test.go => integrations/tests.go (55%) rename time_test.go => integrations/time_test.go (90%) rename types_null_test.go => integrations/types_null_test.go (54%) rename types_test.go => integrations/types_test.go (75%) create mode 100644 internal/json/json.go create mode 100644 internal/statements/cache.go create mode 100644 internal/statements/column_map.go rename statement_exprparam.go => internal/statements/expr_param.go (64%) create mode 100644 internal/statements/insert.go create mode 100644 internal/statements/pk.go create mode 100644 internal/statements/query.go create mode 100644 internal/statements/statement.go rename statement_args.go => internal/statements/statement_args.go (63%) rename statement_test.go => internal/statements/statement_test.go (56%) create mode 100644 internal/statements/update.go create mode 100644 internal/statements/values.go create mode 100644 internal/utils/name.go create mode 100644 internal/utils/reflect.go create mode 100644 internal/utils/slice.go create mode 100644 internal/utils/sql.go create mode 100644 internal/utils/strings.go rename uuid.go => internal/utils/uuid.go (99%) create mode 100644 internal/utils/zero.go create mode 100644 internal/utils/zero_test.go rename logger.go => log/logger.go (64%) create mode 100644 log/logger_context.go rename syslogger.go => log/syslogger.go (88%) create mode 100644 names/mapper.go create mode 100644 names/mapper_test.go create mode 100644 names/table_name.go create mode 100644 names/table_name_test.go create mode 100644 schemas/column.go create mode 100644 schemas/index.go create mode 100644 schemas/pk.go create mode 100644 schemas/pk_test.go create mode 100644 schemas/quote.go create mode 100644 schemas/quote_test.go create mode 100644 schemas/table.go create mode 100644 schemas/table_test.go create mode 100644 schemas/type.go delete mode 100644 session_context.go delete mode 100644 session_context_test.go delete mode 100644 session_pk_test.go delete mode 100644 session_sum_test.go delete mode 100644 statement.go delete mode 100644 statement_columnmap.go delete mode 100644 statement_quote.go delete mode 100644 table_name.go delete mode 100644 table_name_test.go delete mode 100644 tag_cache_test.go delete mode 100644 tag_extends_test.go delete mode 100644 tag_id_test.go delete mode 100644 tag_test.go delete mode 100644 tag_version_test.go create mode 100644 tags/parser.go create mode 100644 tags/parser_test.go rename tag.go => tags/tag.go (72%) create mode 100644 tags/tag_test.go delete mode 100644 transaction.go delete mode 100644 transancation_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/.travis.yml create mode 100644 
vendor/github.com/syndtr/goleveldb/LICENSE create mode 100644 vendor/github.com/syndtr/goleveldb/README.md create mode 100644 vendor/github.com/syndtr/goleveldb/go.mod create mode 100644 vendor/github.com/syndtr/goleveldb/go.sum create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/batch.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/batch_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/bench_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/cache/bench_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/cache/cache.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/cache/cache_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/cache/lru.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/comparer.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/comparer/bytes_comparer.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/comparer/comparer.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/corrupt_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/db.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/db_compaction.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/db_iter.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/db_snapshot.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/db_state.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/db_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/db_transaction.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/db_util.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/db_write.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/doc.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/errors.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/errors/errors.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/external_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/filter.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/filter/bloom.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/filter/bloom_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/filter/filter.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/iterator/array_iter.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/iterator/array_iter_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/iterator/indexed_iter.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/iterator/indexed_iter_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/iterator/iter.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/iterator/iter_suite_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/iterator/merged_iter.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/iterator/merged_iter_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/journal/journal.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/journal/journal_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/key.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/key_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/leveldb_suite_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/memdb/bench_test.go 
create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/memdb/memdb.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/memdb/memdb_suite_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/memdb/memdb_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/opt/options.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/options.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/session.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/session_compaction.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/session_record.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/session_record_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/session_util.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/storage.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_nacl.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_plan9.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_solaris.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_unix.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_windows.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/storage/mem_storage.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/storage/mem_storage_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/storage/storage.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/table.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/table/block_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/table/reader.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/table/table.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/table/table_suite_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/table/table_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/table/writer.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/table_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/testutil/db.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/testutil/ginkgo.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/testutil/iter.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/testutil/kv.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/testutil/kvtest.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/testutil/storage.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/testutil/util.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/testutil_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/util.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/util/buffer.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/util/buffer_pool.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/util/buffer_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/util/crc32.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/util/hash.go create mode 100644 
vendor/github.com/syndtr/goleveldb/leveldb/util/hash_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/util/range.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/util/util.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/version.go create mode 100644 vendor/github.com/syndtr/goleveldb/leveldb/version_test.go create mode 100644 vendor/github.com/syndtr/goleveldb/manualtest/dbstress/key.go create mode 100644 vendor/github.com/syndtr/goleveldb/manualtest/dbstress/main.go create mode 100644 vendor/github.com/syndtr/goleveldb/manualtest/filelock/main.go diff --git a/caches/cache.go b/caches/cache.go new file mode 100644 index 0000000..afc49b1 --- /dev/null +++ b/caches/cache.go @@ -0,0 +1,99 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package caches + +import ( + "bytes" + "encoding/gob" + "errors" + "fmt" + "strings" + "time" + + "github.com/xormplus/xorm/schemas" +) + +const ( + // CacheExpired is the default cache expiration time + CacheExpired = 60 * time.Minute + // CacheMaxMemory is not used now + CacheMaxMemory = 256 + // CacheGcInterval represents the interval time to clear all expired nodes + CacheGcInterval = 10 * time.Minute + // CacheGcMaxRemoved represents the max nodes removed per GC run + CacheGcMaxRemoved = 20 +) + +// list all the errors +var ( + ErrCacheMiss = errors.New("xorm/cache: key not found") + ErrNotStored = errors.New("xorm/cache: not stored") + // ErrNotExist record does not exist error + ErrNotExist = errors.New("Record does not exist") +) + +// CacheStore is an interface to store cache +type CacheStore interface { + // key is primary key or composite primary key + // value is struct's pointer + // key format : <tablename>-p-<pk1>-<pk2>... + Put(key string, value interface{}) error + Get(key string) (interface{}, error) + Del(key string) error +} + +// Cacher is an interface to provide cache +// id format : u-<pk1>-<pk2>...
+type Cacher interface { + GetIds(tableName, sql string) interface{} + GetBean(tableName string, id string) interface{} + PutIds(tableName, sql string, ids interface{}) + PutBean(tableName string, id string, obj interface{}) + DelIds(tableName, sql string) + DelBean(tableName string, id string) + ClearIds(tableName string) + ClearBeans(tableName string) +} + +func encodeIds(ids []schemas.PK) (string, error) { + buf := new(bytes.Buffer) + enc := gob.NewEncoder(buf) + err := enc.Encode(ids) + + return buf.String(), err +} + +func decodeIds(s string) ([]schemas.PK, error) { + pks := make([]schemas.PK, 0) + + dec := gob.NewDecoder(strings.NewReader(s)) + err := dec.Decode(&pks) + + return pks, err +} + +// GetCacheSql returns cacher PKs via SQL +func GetCacheSql(m Cacher, tableName, sql string, args interface{}) ([]schemas.PK, error) { + bytes := m.GetIds(tableName, GenSqlKey(sql, args)) + if bytes == nil { + return nil, errors.New("Not Exist") + } + return decodeIds(bytes.(string)) +} + +// PutCacheSql puts cacher SQL and PKs +func PutCacheSql(m Cacher, ids []schemas.PK, tableName, sql string, args interface{}) error { + bytes, err := encodeIds(ids) + if err != nil { + return err + } + m.PutIds(tableName, GenSqlKey(sql, args), bytes) + return nil +} + +// GenSqlKey generates cache key +func GenSqlKey(sql string, args interface{}) string { + return fmt.Sprintf("%v-%v", sql, args) +} diff --git a/caches/encode.go b/caches/encode.go new file mode 100644 index 0000000..4ba3992 --- /dev/null +++ b/caches/encode.go @@ -0,0 +1,58 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package caches + +import ( + "bytes" + "crypto/md5" + "encoding/gob" + "encoding/json" + "fmt" + "io" +) + +// md5 hash string +func Md5(str string) string { + m := md5.New() + io.WriteString(m, str) + return fmt.Sprintf("%x", m.Sum(nil)) +} +func Encode(data interface{}) ([]byte, error) { + //return JsonEncode(data) + return GobEncode(data) +} + +func Decode(data []byte, to interface{}) error { + //return JsonDecode(data, to) + return GobDecode(data, to) +} + +func GobEncode(data interface{}) ([]byte, error) { + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + err := enc.Encode(&data) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func GobDecode(data []byte, to interface{}) error { + buf := bytes.NewBuffer(data) + dec := gob.NewDecoder(buf) + return dec.Decode(to) +} + +func JsonEncode(data interface{}) ([]byte, error) { + val, err := json.Marshal(data) + if err != nil { + return nil, err + } + return val, nil +} + +func JsonDecode(data []byte, to interface{}) error { + return json.Unmarshal(data, to) +} diff --git a/caches/leveldb.go b/caches/leveldb.go new file mode 100644 index 0000000..d1a177a --- /dev/null +++ b/caches/leveldb.go @@ -0,0 +1,94 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
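Taken together, the pieces above define the caching contract: PutCacheSql gob-encodes a []schemas.PK slice and files it under GenSqlKey(sql, args), and GetCacheSql decodes it back. A minimal round-trip sketch follows; it is not part of the patch itself, it leans on the MemoryStore and LRUCacher that appear later in this diff, and the table name and SQL are made up:

    package main

    import (
        "fmt"

        "github.com/xormplus/xorm/caches"
        "github.com/xormplus/xorm/schemas"
    )

    func main() {
        // An LRU front end over the in-memory CacheStore shown later.
        cacher := caches.NewLRUCacher(caches.NewMemoryStore(), 1000)

        ids := []schemas.PK{{1}, {2}} // primary keys matched by the query
        sql := "SELECT id FROM user WHERE age > ?"
        args := []interface{}{18}

        // PutCacheSql gob-encodes ids and stores them under GenSqlKey(sql, args).
        if err := caches.PutCacheSql(cacher, ids, "user", sql, args); err != nil {
            panic(err)
        }
        got, err := caches.GetCacheSql(cacher, "user", sql, args)
        fmt.Println(got, err) // [[1] [2]] <nil>
    }
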
+ +package caches + +import ( + "log" + + "github.com/syndtr/goleveldb/leveldb" +) + +// LevelDBStore implements CacheStore provide local machine +type LevelDBStore struct { + store *leveldb.DB + Debug bool + v interface{} +} + +var _ CacheStore = &LevelDBStore{} + +func NewLevelDBStore(dbfile string) (*LevelDBStore, error) { + db := &LevelDBStore{} + h, err := leveldb.OpenFile(dbfile, nil) + if err != nil { + return nil, err + } + db.store = h + return db, nil +} + +func (s *LevelDBStore) Put(key string, value interface{}) error { + val, err := Encode(value) + if err != nil { + if s.Debug { + log.Println("[LevelDB]EncodeErr: ", err, "Key:", key) + } + return err + } + err = s.store.Put([]byte(key), val, nil) + if err != nil { + if s.Debug { + log.Println("[LevelDB]PutErr: ", err, "Key:", key) + } + return err + } + if s.Debug { + log.Println("[LevelDB]Put: ", key) + } + return err +} + +func (s *LevelDBStore) Get(key string) (interface{}, error) { + data, err := s.store.Get([]byte(key), nil) + if err != nil { + if s.Debug { + log.Println("[LevelDB]GetErr: ", err, "Key:", key) + } + if err == leveldb.ErrNotFound { + return nil, ErrNotExist + } + return nil, err + } + + err = Decode(data, &s.v) + if err != nil { + if s.Debug { + log.Println("[LevelDB]DecodeErr: ", err, "Key:", key) + } + return nil, err + } + if s.Debug { + log.Println("[LevelDB]Get: ", key, s.v) + } + return s.v, err +} + +func (s *LevelDBStore) Del(key string) error { + err := s.store.Delete([]byte(key), nil) + if err != nil { + if s.Debug { + log.Println("[LevelDB]DelErr: ", err, "Key:", key) + } + return err + } + if s.Debug { + log.Println("[LevelDB]Del: ", key) + } + return err +} + +func (s *LevelDBStore) Close() { + s.store.Close() +} diff --git a/caches/leveldb_test.go b/caches/leveldb_test.go new file mode 100644 index 0000000..35981db --- /dev/null +++ b/caches/leveldb_test.go @@ -0,0 +1,39 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package caches + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLevelDBStore(t *testing.T) { + store, err := NewLevelDBStore("./level.db") + assert.NoError(t, err) + + var kvs = map[string]interface{}{ + "a": "b", + } + for k, v := range kvs { + assert.NoError(t, store.Put(k, v)) + } + + for k, v := range kvs { + val, err := store.Get(k) + assert.NoError(t, err) + assert.EqualValues(t, v, val) + } + + for k := range kvs { + err := store.Del(k) + assert.NoError(t, err) + } + + for k := range kvs { + _, err := store.Get(k) + assert.EqualValues(t, ErrNotExist, err) + } +} diff --git a/cache_lru.go b/caches/lru.go similarity index 93% rename from cache_lru.go rename to caches/lru.go index 74f7a38..6b45ac9 100644 --- a/cache_lru.go +++ b/caches/lru.go @@ -2,15 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-package xorm +package caches import ( "container/list" "fmt" "sync" "time" - - "github.com/xormplus/core" ) // LRUCacher implments cache object facilities @@ -19,7 +17,7 @@ type LRUCacher struct { sqlList *list.List idIndex map[string]map[string]*list.Element sqlIndex map[string]map[string]*list.Element - store core.CacheStore + store CacheStore mutex sync.Mutex MaxElementSize int Expired time.Duration @@ -27,15 +25,15 @@ type LRUCacher struct { } // NewLRUCacher creates a cacher -func NewLRUCacher(store core.CacheStore, maxElementSize int) *LRUCacher { +func NewLRUCacher(store CacheStore, maxElementSize int) *LRUCacher { return NewLRUCacher2(store, 3600*time.Second, maxElementSize) } // NewLRUCacher2 creates a cache include different params -func NewLRUCacher2(store core.CacheStore, expired time.Duration, maxElementSize int) *LRUCacher { +func NewLRUCacher2(store CacheStore, expired time.Duration, maxElementSize int) *LRUCacher { cacher := &LRUCacher{store: store, idList: list.New(), sqlList: list.New(), Expired: expired, - GcInterval: core.CacheGcInterval, MaxElementSize: maxElementSize, + GcInterval: CacheGcInterval, MaxElementSize: maxElementSize, sqlIndex: make(map[string]map[string]*list.Element), idIndex: make(map[string]map[string]*list.Element), } @@ -57,7 +55,7 @@ func (m *LRUCacher) GC() { defer m.mutex.Unlock() var removedNum int for e := m.idList.Front(); e != nil; { - if removedNum <= core.CacheGcMaxRemoved && + if removedNum <= CacheGcMaxRemoved && time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired { removedNum++ next := e.Next() @@ -71,7 +69,7 @@ func (m *LRUCacher) GC() { removedNum = 0 for e := m.sqlList.Front(); e != nil; { - if removedNum <= core.CacheGcMaxRemoved && + if removedNum <= CacheGcMaxRemoved && time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired { removedNum++ next := e.Next() @@ -268,11 +266,11 @@ type sqlNode struct { } func genSQLKey(sql string, args interface{}) string { - return fmt.Sprintf("%v-%v", sql, args) + return fmt.Sprintf("%s-%v", sql, args) } func genID(prefix string, id string) string { - return fmt.Sprintf("%v-%v", prefix, id) + return fmt.Sprintf("%s-%s", prefix, id) } func newIDNode(tbName string, id string) *idNode { diff --git a/cache_lru_test.go b/caches/lru_test.go similarity index 93% rename from cache_lru_test.go rename to caches/lru_test.go index 060357f..1b1e9c5 100644 --- a/cache_lru_test.go +++ b/caches/lru_test.go @@ -2,13 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package caches import ( "testing" "github.com/stretchr/testify/assert" - "github.com/xormplus/core" + "github.com/xormplus/xorm/schemas" ) func TestLRUCache(t *testing.T) { @@ -20,7 +20,7 @@ func TestLRUCache(t *testing.T) { cacher := NewLRUCacher(store, 10000) tableName := "cache_object1" - pks := []core.PK{ + pks := []schemas.PK{ {1}, {2}, } diff --git a/caches/manager.go b/caches/manager.go new file mode 100644 index 0000000..0504521 --- /dev/null +++ b/caches/manager.go @@ -0,0 +1,56 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
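Before the Manager source, a wiring sketch: because LevelDBStore satisfies CacheStore, it can persist what the LRU cacher holds, while the Manager defined next routes cachers per table. This is illustrative only, not from the patch; the file path, table names, and sizes are invented:

    package main

    import (
        "time"

        "github.com/xormplus/xorm/caches"
    )

    func main() {
        store, err := caches.NewLevelDBStore("./cache.db") // survives restarts
        if err != nil {
            panic(err)
        }
        defer store.Close()

        mgr := caches.NewManager()
        mgr.SetDefaultCacher(caches.NewLRUCacher2(store, 30*time.Minute, 1000))
        mgr.SetCacher("big_log", nil) // a nil per-table entry opts the table out

        // GetCacher prefers the per-table entry; only unknown tables fall
        // back to the default (unless the global cache is disabled).
        _ = mgr.GetCacher("big_log") // nil
        _ = mgr.GetCacher("user")    // the default LRU cacher
    }
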
+ +package caches + +import "sync" + +type Manager struct { + cacher Cacher + disableGlobalCache bool + + cachers map[string]Cacher + cacherLock sync.RWMutex +} + +func NewManager() *Manager { + return &Manager{ + cachers: make(map[string]Cacher), + } +} + +// SetDisableGlobalCache disable global cache or not +func (mgr *Manager) SetDisableGlobalCache(disable bool) { + if mgr.disableGlobalCache != disable { + mgr.disableGlobalCache = disable + } +} + +func (mgr *Manager) SetCacher(tableName string, cacher Cacher) { + mgr.cacherLock.Lock() + mgr.cachers[tableName] = cacher + mgr.cacherLock.Unlock() +} + +func (mgr *Manager) GetCacher(tableName string) Cacher { + var cacher Cacher + var ok bool + mgr.cacherLock.RLock() + cacher, ok = mgr.cachers[tableName] + mgr.cacherLock.RUnlock() + if !ok && !mgr.disableGlobalCache { + cacher = mgr.cacher + } + return cacher +} + +// SetDefaultCacher set the default cacher. Xorm's default not enable cacher. +func (mgr *Manager) SetDefaultCacher(cacher Cacher) { + mgr.cacher = cacher +} + +// GetDefaultCacher returns the default cacher +func (mgr *Manager) GetDefaultCacher() Cacher { + return mgr.cacher +} diff --git a/cache_memory_store.go b/caches/memory_store.go similarity index 92% rename from cache_memory_store.go rename to caches/memory_store.go index b07898b..f16254d 100644 --- a/cache_memory_store.go +++ b/caches/memory_store.go @@ -2,15 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package caches import ( "sync" - - "github.com/xormplus/core" ) -var _ core.CacheStore = NewMemoryStore() +var _ CacheStore = NewMemoryStore() // MemoryStore represents in-memory store type MemoryStore struct { diff --git a/cache_memory_store_test.go b/caches/memory_store_test.go similarity index 91% rename from cache_memory_store_test.go rename to caches/memory_store_test.go index fc27ae3..12db4ea 100644 --- a/cache_memory_store_test.go +++ b/caches/memory_store_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package caches import ( "testing" @@ -25,12 +25,12 @@ func TestMemoryStore(t *testing.T) { assert.EqualValues(t, v, val) } - for k, _ := range kvs { + for k := range kvs { err := store.Del(k) assert.NoError(t, err) } - for k, _ := range kvs { + for k := range kvs { _, err := store.Get(k) assert.EqualValues(t, ErrNotExist, err) } diff --git a/context_test.go b/context_test.go deleted file mode 100644 index adc651d..0000000 --- a/context_test.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.8 - -package xorm - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestPingContext(t *testing.T) { - assert.NoError(t, prepareEngine()) - - ctx, canceled := context.WithTimeout(context.Background(), 10*time.Second) - defer canceled() - - err := testEngine.(*Engine).PingContext(ctx) - assert.NoError(t, err) - - // TODO: Since EngineInterface should be compitable with old Go version, PingContext is not supported. 
- /* - ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) - err := testEngine.PingContext(ctx) - assert.NoError(t, err) - */ -} diff --git a/contexts/context_cache.go b/contexts/context_cache.go new file mode 100644 index 0000000..0d0f0f0 --- /dev/null +++ b/contexts/context_cache.go @@ -0,0 +1,30 @@ +// Copyright 2018 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package contexts + +// ContextCache is the interface that operates the cache data. +type ContextCache interface { + // Put puts value into cache with key. + Put(key string, val interface{}) + // Get gets cached value by given key. + Get(key string) interface{} +} + +type memoryContextCache map[string]interface{} + +// NewMemoryContextCache return memoryContextCache +func NewMemoryContextCache() memoryContextCache { + return make(map[string]interface{}) +} + +// Put puts value into cache with key. +func (m memoryContextCache) Put(key string, val interface{}) { + m[key] = val +} + +// Get gets cached value by given key. +func (m memoryContextCache) Get(key string) interface{} { + return m[key] +} diff --git a/convert.go b/convert.go index 131dd1a..d0cadca 100644 --- a/convert.go +++ b/convert.go @@ -12,6 +12,8 @@ import ( "fmt" "math" "reflect" + + // "runtime" "strconv" "time" ) @@ -28,11 +30,10 @@ func strconvErr(err error) error { func cloneBytes(b []byte) []byte { if b == nil { return nil - } else { - c := make([]byte, len(b)) - copy(c, b) - return c } + c := make([]byte, len(b)) + copy(c, b) + return c } func asString(src interface{}) string { @@ -288,66 +289,141 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) } -func convertFloat(v interface{}) (float64, error) { - switch v.(type) { - case float32: - return float64(v.(float32)), nil - case float64: - return v.(float64), nil - case string: - i, err := strconv.ParseFloat(v.(string), 64) +func asBool(bs []byte) (bool, error) { + if len(bs) == 0 { + return false, nil + } + if bs[0] == 0x00 { + return false, nil + } else if bs[0] == 0x01 { + return true, nil + } + return strconv.ParseBool(string(bs)) +} + +// str2PK convert string value to primary key value according to tp +func str2PKValue(s string, tp reflect.Type) (reflect.Value, error) { + var err error + var result interface{} + var defReturn = reflect.Zero(tp) + + switch tp.Kind() { + case reflect.Int: + result, err = strconv.Atoi(s) if err != nil { - return 0, err + return defReturn, fmt.Errorf("convert %s as int: %s", s, err.Error()) } - return i, nil - case []byte: - i, err := strconv.ParseFloat(string(v.([]byte)), 64) + case reflect.Int8: + x, err := strconv.Atoi(s) if err != nil { - return 0, err + return defReturn, fmt.Errorf("convert %s as int8: %s", s, err.Error()) } - return i, nil - } - return 0, fmt.Errorf("unsupported type: %v", v) -} - -func convertInt(v interface{}) (int64, error) { - switch v.(type) { - case int: - return int64(v.(int)), nil - case int8: - return int64(v.(int8)), nil - case int16: - return int64(v.(int16)), nil - case int32: - return int64(v.(int32)), nil - case int64: - return v.(int64), nil - case []byte: - i, err := strconv.ParseInt(string(v.([]byte)), 10, 64) + result = int8(x) + case reflect.Int16: + x, err := strconv.Atoi(s) if err != nil { - return 0, err + return defReturn, fmt.Errorf("convert %s as int16: %s", s, err.Error()) } - return i, nil - case string: - i, err := 
strconv.ParseInt(v.(string), 10, 64) + result = int16(x) + case reflect.Int32: + x, err := strconv.Atoi(s) + if err != nil { + return defReturn, fmt.Errorf("convert %s as int32: %s", s, err.Error()) + } + result = int32(x) + case reflect.Int64: + result, err = strconv.ParseInt(s, 10, 64) + if err != nil { + return defReturn, fmt.Errorf("convert %s as int64: %s", s, err.Error()) + } + case reflect.Uint: + x, err := strconv.ParseUint(s, 10, 64) if err != nil { - return 0, err + return defReturn, fmt.Errorf("convert %s as uint: %s", s, err.Error()) } - return i, nil + result = uint(x) + case reflect.Uint8: + x, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return defReturn, fmt.Errorf("convert %s as uint8: %s", s, err.Error()) + } + result = uint8(x) + case reflect.Uint16: + x, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return defReturn, fmt.Errorf("convert %s as uint16: %s", s, err.Error()) + } + result = uint16(x) + case reflect.Uint32: + x, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return defReturn, fmt.Errorf("convert %s as uint32: %s", s, err.Error()) + } + result = uint32(x) + case reflect.Uint64: + result, err = strconv.ParseUint(s, 10, 64) + if err != nil { + return defReturn, fmt.Errorf("convert %s as uint64: %s", s, err.Error()) + } + case reflect.String: + result = s + default: + return defReturn, errors.New("unsupported convert type") } - return 0, fmt.Errorf("unsupported type: %v", v) + return reflect.ValueOf(result).Convert(tp), nil } -func asBool(bs []byte) (bool, error) { - if len(bs) == 0 { - return false, nil +func str2PK(s string, tp reflect.Type) (interface{}, error) { + v, err := str2PKValue(s, tp) + if err != nil { + return nil, err } - if bs[0] == 0x00 { - return false, nil - } else if bs[0] == 0x01 { - return true, nil + return v.Interface(), nil +} + +func int64ToIntValue(id int64, tp reflect.Type) reflect.Value { + var v interface{} + kind := tp.Kind() + + if kind == reflect.Ptr { + kind = tp.Elem().Kind() } - return strconv.ParseBool(string(bs)) + + switch kind { + case reflect.Int16: + temp := int16(id) + v = &temp + case reflect.Int32: + temp := int32(id) + v = &temp + case reflect.Int: + temp := int(id) + v = &temp + case reflect.Int64: + temp := id + v = &temp + case reflect.Uint16: + temp := uint16(id) + v = &temp + case reflect.Uint32: + temp := uint32(id) + v = &temp + case reflect.Uint64: + temp := uint64(id) + v = &temp + case reflect.Uint: + temp := uint(id) + v = &temp + } + + if tp.Kind() == reflect.Ptr { + return reflect.ValueOf(v).Convert(tp) + } + return reflect.ValueOf(v).Elem().Convert(tp) +} + +func int64ToInt(id int64, tp reflect.Type) interface{} { + return int64ToIntValue(id, tp).Interface() } func EncodeString(s string) []byte { diff --git a/convert/conversion.go b/convert/conversion.go new file mode 100644 index 0000000..16f1a92 --- /dev/null +++ b/convert/conversion.go @@ -0,0 +1,12 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +// Conversion is an interface. A type implements Conversion will according +// the custom method to fill into database and retrieve from database. +type Conversion interface { + FromDB([]byte) error + ToDB() ([]byte, error) +} diff --git a/core/db.go b/core/db.go new file mode 100644 index 0000000..5b5f03b --- /dev/null +++ b/core/db.go @@ -0,0 +1,290 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "regexp" + "sync" + "time" + + "github.com/xormplus/xorm/log" + "github.com/xormplus/xorm/names" +) + +var ( + // DefaultCacheSize sets the default cache size + DefaultCacheSize = 200 +) + +func MapToSlice(query string, mp interface{}) (string, []interface{}, error) { + vv := reflect.ValueOf(mp) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return "", []interface{}{}, ErrNoMapPointer + } + + args := make([]interface{}, 0, len(vv.Elem().MapKeys())) + var err error + query = re.ReplaceAllStringFunc(query, func(src string) string { + v := vv.Elem().MapIndex(reflect.ValueOf(src[1:])) + if !v.IsValid() { + err = fmt.Errorf("map key %s is missing", src[1:]) + } else { + args = append(args, v.Interface()) + } + return "?" + }) + + return query, args, err +} + +func StructToSlice(query string, st interface{}) (string, []interface{}, error) { + vv := reflect.ValueOf(st) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return "", []interface{}{}, ErrNoStructPointer + } + + args := make([]interface{}, 0) + var err error + query = re.ReplaceAllStringFunc(query, func(src string) string { + fv := vv.Elem().FieldByName(src[1:]).Interface() + if v, ok := fv.(driver.Valuer); ok { + var value driver.Value + value, err = v.Value() + if err != nil { + return "?" + } + args = append(args, value) + } else { + args = append(args, fv) + } + return "?" + }) + if err != nil { + return "", []interface{}{}, err + } + return query, args, nil +} + +type cacheStruct struct { + value reflect.Value + idx int +} + +var ( + _ QueryExecuter = &DB{} +) + +// DB is a wrap of sql.DB with extra contents +type DB struct { + *sql.DB + Mapper names.Mapper + reflectCache map[reflect.Type]*cacheStruct + reflectCacheMutex sync.RWMutex + Logger log.ContextLogger +} + +// Open opens a database +func Open(driverName, dataSourceName string) (*DB, error) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + return nil, err + } + return &DB{ + DB: db, + Mapper: names.NewCacheMapper(&names.SnakeMapper{}), + reflectCache: make(map[reflect.Type]*cacheStruct), + }, nil +} + +// FromDB creates a DB from a sql.DB +func FromDB(db *sql.DB) *DB { + return &DB{ + DB: db, + Mapper: names.NewCacheMapper(&names.SnakeMapper{}), + reflectCache: make(map[reflect.Type]*cacheStruct), + } +} + +// NeedLogSQL returns true if need to log SQL +func (db *DB) NeedLogSQL(ctx context.Context) bool { + if db.Logger == nil { + return false + } + + v := ctx.Value(log.SessionShowSQLKey) + if showSQL, ok := v.(bool); ok { + return showSQL + } + return db.Logger.IsShowSQL() +} + +func (db *DB) reflectNew(typ reflect.Type) reflect.Value { + db.reflectCacheMutex.Lock() + defer db.reflectCacheMutex.Unlock() + cs, ok := db.reflectCache[typ] + if !ok || cs.idx+1 > DefaultCacheSize-1 { + cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), DefaultCacheSize, DefaultCacheSize), 0} + db.reflectCache[typ] = cs + } else { + cs.idx = cs.idx + 1 + } + return cs.value.Index(cs.idx).Addr() +} + +// QueryContext overwrites sql.DB.QueryContext +func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + start := time.Now() + showSQL := db.NeedLogSQL(ctx) + if showSQL { + db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + }) + } + rows, err := 
db.DB.QueryContext(ctx, query, args...) + if showSQL { + db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } + if err != nil { + if rows != nil { + rows.Close() + } + return nil, err + } + return &Rows{rows, db}, nil +} + +// Query overwrites sql.DB.Query +func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { + return db.QueryContext(context.Background(), query, args...) +} + +// QueryMapContext executes query with parameters via map and context +func (db *DB) QueryMapContext(ctx context.Context, query string, mp interface{}) (*Rows, error) { + query, args, err := MapToSlice(query, mp) + if err != nil { + return nil, err + } + return db.QueryContext(ctx, query, args...) +} + +// QueryMap executes query with parameters via map +func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) { + return db.QueryMapContext(context.Background(), query, mp) +} + +func (db *DB) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) { + query, args, err := StructToSlice(query, st) + if err != nil { + return nil, err + } + return db.QueryContext(ctx, query, args...) +} + +func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) { + return db.QueryStructContext(context.Background(), query, st) +} + +func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + return &Row{nil, err} + } + return &Row{rows, nil} +} + +func (db *DB) QueryRow(query string, args ...interface{}) *Row { + return db.QueryRowContext(context.Background(), query, args...) +} + +func (db *DB) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row { + query, args, err := MapToSlice(query, mp) + if err != nil { + return &Row{nil, err} + } + return db.QueryRowContext(ctx, query, args...) +} + +func (db *DB) QueryRowMap(query string, mp interface{}) *Row { + return db.QueryRowMapContext(context.Background(), query, mp) +} + +func (db *DB) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row { + query, args, err := StructToSlice(query, st) + if err != nil { + return &Row{nil, err} + } + return db.QueryRowContext(ctx, query, args...) +} + +func (db *DB) QueryRowStruct(query string, st interface{}) *Row { + return db.QueryRowStructContext(context.Background(), query, st) +} + +var ( + re = regexp.MustCompile(`[?](\w+)`) +) + +// ExecMapContext exec map with context.Context +// insert into (name) values (?) +// insert into (name) values (?name) +func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) { + query, args, err := MapToSlice(query, mp) + if err != nil { + return nil, err + } + return db.ExecContext(ctx, query, args...) +} + +func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) { + return db.ExecMapContext(context.Background(), query, mp) +} + +func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) { + query, args, err := StructToSlice(query, st) + if err != nil { + return nil, err + } + return db.ExecContext(ctx, query, args...) 
+} + +func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + start := time.Now() + showSQL := db.NeedLogSQL(ctx) + if showSQL { + db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + }) + } + res, err := db.DB.ExecContext(ctx, query, args...) + if showSQL { + db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } + return res, err +} + +func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) { + return db.ExecStructContext(context.Background(), query, st) +} diff --git a/core/db_test.go b/core/db_test.go new file mode 100644 index 0000000..c554249 --- /dev/null +++ b/core/db_test.go @@ -0,0 +1,684 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "errors" + "flag" + "os" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/mattn/go-sqlite3" + "github.com/xormplus/xorm/names" +) + +var ( + dbtype = flag.String("dbtype", "sqlite3", "database type") + dbConn = flag.String("dbConn", "./db_test.db", "database connect string") + createTableSql string +) + +func TestMain(m *testing.M) { + flag.Parse() + + switch *dbtype { + case "sqlite3": + createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, " + + "`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);" + case "mysql": + fallthrough + default: + createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTO_INCREMENT NOT NULL, `name` TEXT NULL, " + + "`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);" + } + + exitCode := m.Run() + + os.Exit(exitCode) +} + +func testOpen() (*DB, error) { + switch *dbtype { + case "sqlite3": + os.Remove("./test.db") + return Open("sqlite3", "./test.db") + case "mysql": + return Open("mysql", *dbConn) + default: + panic("no db type") + } +} + +func BenchmarkOriQuery(b *testing.B) { + b.StopTimer() + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?, ?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + var Id int64 + var Name, Title, Alias, NickName string + var Age float32 + var Created NullTime + err = rows.Scan(&Id, &Name, &Title, &Age, &Alias, &NickName, &Created) + if err != nil { + b.Error(err) + } + //fmt.Println(Id, Name, Title, Age, Alias, NickName) + } + rows.Close() + } +} + +type User struct { + Id int64 + Name string + Title string + Age float32 + Alias string + NickName string + Created NullTime +} + +func BenchmarkStructQuery(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?, ?)", + 
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + var user User + err = rows.ScanStructByIndex(&user) + if err != nil { + b.Error(err) + } + if user.Name != "xlw" { + b.Log(user) + b.Error(errors.New("name should be xlw")) + } + } + rows.Close() + } +} + +func BenchmarkStruct2Query(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + db.Mapper = names.NewCacheMapper(&names.SnakeMapper{}) + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + var user User + err = rows.ScanStructByName(&user) + if err != nil { + b.Error(err) + } + if user.Name != "xlw" { + b.Log(user) + b.Error(errors.New("name should be xlw")) + } + } + rows.Close() + } +} + +func BenchmarkSliceInterfaceQuery(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (`name`, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + cols, err := rows.Columns() + if err != nil { + b.Error(err) + } + + for rows.Next() { + slice := make([]interface{}, len(cols)) + err = rows.ScanSlice(&slice) + if err != nil { + b.Error(err) + } + b.Log(slice) + switch slice[1].(type) { + case *string: + if *slice[1].(*string) != "xlw" { + b.Error(errors.New("name should be xlw")) + } + case []byte: + if string(slice[1].([]byte)) != "xlw" { + b.Error(errors.New("name should be xlw")) + } + } + } + + rows.Close() + } +} + +/*func BenchmarkSliceBytesQuery(b *testing.B) { + b.StopTimer() + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + cols, err := rows.Columns() + if err != nil { + b.Error(err) + } + + for rows.Next() { + slice := make([][]byte, len(cols)) + err = rows.ScanSlice(&slice) + if err != nil { + b.Error(err) + } + if string(slice[1]) != "xlw" { + fmt.Println(slice) + b.Error(errors.New("name should be xlw")) + } + } + + rows.Close() + } +} +*/ + +func BenchmarkSliceStringQuery(b *testing.B) { + b.StopTimer() + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if 
err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name, created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + cols, err := rows.Columns() + if err != nil { + b.Error(err) + } + + for rows.Next() { + slice := make([]*string, len(cols)) + err = rows.ScanSlice(&slice) + if err != nil { + b.Error(err) + } + if (*slice[1]) != "xlw" { + b.Log(slice) + b.Error(errors.New("name should be xlw")) + } + } + + rows.Close() + } +} + +func BenchmarkMapInterfaceQuery(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + m := make(map[string]interface{}) + err = rows.ScanMap(&m) + if err != nil { + b.Error(err) + } + switch m["name"].(type) { + case string: + if m["name"].(string) != "xlw" { + b.Log(m) + b.Error(errors.New("name should be xlw")) + } + case []byte: + if string(m["name"].([]byte)) != "xlw" { + b.Log(m) + b.Error(errors.New("name should be xlw")) + } + } + } + + rows.Close() + } +} + +/*func BenchmarkMapBytesQuery(b *testing.B) { + b.StopTimer() + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + m := make(map[string][]byte) + err = rows.ScanMap(&m) + if err != nil { + b.Error(err) + } + if string(m["name"]) != "xlw" { + fmt.Println(m) + b.Error(errors.New("name should be xlw")) + } + } + + rows.Close() + } +} +*/ +/* +func BenchmarkMapStringQuery(b *testing.B) { + b.StopTimer() + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + m := make(map[string]string) + err = rows.ScanMap(&m) + if err != nil { + b.Error(err) + } + if m["name"] != "xlw" { + fmt.Println(m) + b.Error(errors.New("name should be xlw")) + } + } + + rows.Close() + } +}*/ + +func BenchmarkExec(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { 
+ b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + _, err = db.Exec("insert into user (`name`, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } +} + +func BenchmarkExecMap(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + b.StartTimer() + + mp := map[string]interface{}{ + "name": "xlw", + "title": "tester", + "age": 1.2, + "alias": "lunny", + "nick_name": "lunny xiao", + "created": time.Now(), + } + + for i := 0; i < b.N; i++ { + _, err = db.ExecMap("insert into user (`name`, title, age, alias, nick_name, created) "+ + "values (?name,?title,?age,?alias,?nick_name,?created)", + &mp) + if err != nil { + b.Error(err) + } + } +} + +func TestExecMap(t *testing.T) { + db, err := testOpen() + if err != nil { + t.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + t.Error(err) + } + + mp := map[string]interface{}{ + "name": "xlw", + "title": "tester", + "age": 1.2, + "alias": "lunny", + "nick_name": "lunny xiao", + "created": time.Now(), + } + + _, err = db.ExecMap("insert into user (`name`, title, age, alias, nick_name,created) "+ + "values (?name,?title,?age,?alias,?nick_name,?created)", + &mp) + if err != nil { + t.Error(err) + } + + rows, err := db.Query("select * from user") + if err != nil { + t.Error(err) + } + + for rows.Next() { + var user User + err = rows.ScanStructByName(&user) + if err != nil { + t.Error(err) + } + t.Log("--", user) + } +} + +func TestExecStruct(t *testing.T) { + db, err := testOpen() + if err != nil { + t.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + t.Error(err) + } + + user := User{Name: "xlw", + Title: "tester", + Age: 1.2, + Alias: "lunny", + NickName: "lunny xiao", + Created: NullTime(time.Now()), + } + + _, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+ + "values (?Name,?Title,?Age,?Alias,?NickName,?Created)", + &user) + if err != nil { + t.Error(err) + } + + rows, err := db.QueryStruct("select * from user where `name` = ?Name", &user) + if err != nil { + t.Error(err) + } + + for rows.Next() { + var user User + err = rows.ScanStructByName(&user) + if err != nil { + t.Error(err) + } + t.Log("1--", user) + } +} + +func BenchmarkExecStruct(b *testing.B) { + b.StopTimer() + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + b.StartTimer() + + user := User{Name: "xlw", + Title: "tester", + Age: 1.2, + Alias: "lunny", + NickName: "lunny xiao", + Created: NullTime(time.Now()), + } + + for i := 0; i < b.N; i++ { + _, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+ + "values (?Name,?Title,?Age,?Alias,?NickName,?Created)", + &user) + if err != nil { + b.Error(err) + } + } +} diff --git a/core/error.go b/core/error.go new file mode 100644 index 0000000..1fd1834 --- /dev/null +++ b/core/error.go @@ -0,0 +1,14 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package core + +import "errors" + +var ( + // ErrNoMapPointer represents the error returned when the argument is not a map pointer + ErrNoMapPointer = errors.New("mp should be a map's pointer") + // ErrNoStructPointer represents the error returned when the argument is not a struct pointer + ErrNoStructPointer = errors.New("st should be a struct's pointer") +) diff --git a/core/interface.go b/core/interface.go new file mode 100644 index 0000000..a5c8e4e --- /dev/null +++ b/core/interface.go @@ -0,0 +1,22 @@ +package core + +import ( + "context" + "database/sql" +) + +// Queryer represents an interface to run a SQL query and get data from the database +type Queryer interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) +} + +// Executer represents an interface to execute a SQL statement +type Executer interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + +// QueryExecuter combines the Queryer and Executer +type QueryExecuter interface { + Queryer + Executer +} diff --git a/core/rows.go b/core/rows.go new file mode 100644 index 0000000..a1e8bfb --- /dev/null +++ b/core/rows.go @@ -0,0 +1,338 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "database/sql" + "errors" + "reflect" + "sync" +) + +type Rows struct { + *sql.Rows + db *DB +} + +func (rs *Rows) ToMapString() ([]map[string]string, error) { + cols, err := rs.Columns() + if err != nil { + return nil, err + } + + var results = make([]map[string]string, 0, 10) + for rs.Next() { + var record = make(map[string]string, len(cols)) + err = rs.ScanMap(&record) + if err != nil { + return nil, err + } + results = append(results, record) + } + return results, nil +} + +// ScanStructByIndex scans row data into one or more struct pointers, binding columns to struct fields by position
+func (rs *Rows) ScanStructByIndex(dest ...interface{}) error { + if len(dest) == 0 { + return errors.New("at least one struct should be given") + } + + vvvs := make([]reflect.Value, len(dest)) + for i, s := range dest { + vv := reflect.ValueOf(s) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return errors.New("dest should be a struct's pointer") + } + + vvvs[i] = vv.Elem() + } + + cols, err := rs.Columns() + if err != nil { + return err + } + newDest := make([]interface{}, len(cols)) + + var i = 0 + for _, vvv := range vvvs { + for j := 0; j < vvv.NumField(); j++ { + newDest[i] = vvv.Field(j).Addr().Interface() + i++ + } + } + + return rs.Rows.Scan(newDest...)
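As an aside, a minimal sketch of how ScanStructByIndex is meant to be called: columns bind to struct fields strictly by position, so the field order must mirror the select list. The core.Open helper and the user table layout are assumptions borrowed from the tests earlier in this patch, not part of this file:

    package main

    import (
        "fmt"

        _ "github.com/mattn/go-sqlite3"
        "github.com/xormplus/xorm/core"
    )

    // Field order must mirror the column order in the query below.
    type User struct {
        Name  string
        Title string
    }

    func main() {
        // core.Open is assumed to wrap sql.Open and return a *core.DB,
        // as the commented-out benchmarks above suggest.
        db, err := core.Open("sqlite3", "./test.db")
        if err != nil {
            panic(err)
        }
        defer db.Close()

        rows, err := db.Query("select name, title from user")
        if err != nil {
            panic(err)
        }
        defer rows.Close()

        for rows.Next() {
            var u User
            // Each column is scanned into the field at the same index.
            if err := rows.ScanStructByIndex(&u); err != nil {
                panic(err)
            }
            fmt.Println(u.Name, u.Title)
        }
    }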
+} + +var ( + fieldCache = make(map[reflect.Type]map[string]int) + fieldCacheMutex sync.RWMutex +) + +func fieldByName(v reflect.Value, name string) reflect.Value { + t := v.Type() + fieldCacheMutex.RLock() + cache, ok := fieldCache[t] + fieldCacheMutex.RUnlock() + if !ok { + cache = make(map[string]int) + for i := 0; i < v.NumField(); i++ { + cache[t.Field(i).Name] = i + } + fieldCacheMutex.Lock() + fieldCache[t] = cache + fieldCacheMutex.Unlock() + } + + if i, ok := cache[name]; ok { + return v.Field(i) + } + + return reflect.Value{} // an invalid Value, so callers can detect the miss via IsValid +} + +// ScanStructByName scans row data into a struct pointer, binding columns to fields by name through the DB's mapper
+func (rs *Rows) ScanStructByName(dest interface{}) error { + vv := reflect.ValueOf(dest) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return errors.New("dest should be a struct's pointer") + } + + cols, err := rs.Columns() + if err != nil { + return err + } + + newDest := make([]interface{}, len(cols)) + var v EmptyScanner + for j, name := range cols { + f := fieldByName(vv.Elem(), rs.db.Mapper.Table2Obj(name)) + if f.IsValid() { + newDest[j] = f.Addr().Interface() + } else { + newDest[j] = &v + } + } + + return rs.Rows.Scan(newDest...) +} + +// ScanSlice scans row data into a slice pointer; the slice's length should equal the number of columns
+func (rs *Rows) ScanSlice(dest interface{}) error { + vv := reflect.ValueOf(dest) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice { + return errors.New("dest should be a slice's pointer") + } + + vvv := vv.Elem() + cols, err := rs.Columns() + if err != nil { + return err + } + + newDest := make([]interface{}, len(cols)) + + for j := 0; j < len(cols); j++ { + if j >= vvv.Len() { + newDest[j] = reflect.New(vvv.Type().Elem()).Interface() + } else { + newDest[j] = vvv.Index(j).Addr().Interface() + } + } + + err = rs.Rows.Scan(newDest...) + if err != nil { + return err + } + + srcLen := vvv.Len() + for i := srcLen; i < len(cols); i++ { + vvv = reflect.Append(vvv, reflect.ValueOf(newDest[i]).Elem()) + } + vv.Elem().Set(vvv) // write back: Append may have allocated a new slice + return nil +} + +// ScanMap scans row data into a map pointer keyed by column name
+func (rs *Rows) ScanMap(dest interface{}) error { + vv := reflect.ValueOf(dest) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return errors.New("dest should be a map's pointer") + } + + cols, err := rs.Columns() + if err != nil { + return err + } + + newDest := make([]interface{}, len(cols)) + vvv := vv.Elem() + + for i := range cols { + newDest[i] = rs.db.reflectNew(vvv.Type().Elem()).Interface() + } + + err = rs.Rows.Scan(newDest...) + if err != nil { + return err + } + + for i, name := range cols { + vname := reflect.ValueOf(name) + vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem()) + } + + return nil +} + +type Row struct { + rows *Rows + // One of these two will be non-nil: + err error // deferred error for easy chaining +} + +// ErrorRow returns a Row that carries only an error
+func ErrorRow(err error) *Row { + return &Row{ + err: err, + } +} + +// NewRow creates a Row from a Rows and a possible error
+func NewRow(rows *Rows, err error) *Row { + return &Row{rows, err} +} + +func (row *Row) Columns() ([]string, error) { + if row.err != nil { + return nil, row.err + } + return row.rows.Columns() +} + +func (row *Row) Scan(dest ...interface{}) error { + if row.err != nil { + return row.err + } + defer row.rows.Close() + + for _, dp := range dest { + if _, ok := dp.(*sql.RawBytes); ok { + return errors.New("sql: RawBytes isn't allowed on Row.Scan") + } + } + + if !row.rows.Next() { + if err := row.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := row.rows.Scan(dest...)
+ if err != nil { + return err + } + // Make sure the query can be processed to completion with no errors. + return row.rows.Close() +} + +func (row *Row) ScanStructByName(dest interface{}) error { + if row.err != nil { + return row.err + } + defer row.rows.Close() + + if !row.rows.Next() { + if err := row.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := row.rows.ScanStructByName(dest) + if err != nil { + return err + } + // Make sure the query can be processed to completion with no errors. + return row.rows.Close() +} + +func (row *Row) ScanStructByIndex(dest interface{}) error { + if row.err != nil { + return row.err + } + defer row.rows.Close() + + if !row.rows.Next() { + if err := row.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := row.rows.ScanStructByIndex(dest) + if err != nil { + return err + } + // Make sure the query can be processed to completion with no errors. + return row.rows.Close() +} + +// scan data to a slice's pointer, slice's length should equal to columns' number +func (row *Row) ScanSlice(dest interface{}) error { + if row.err != nil { + return row.err + } + defer row.rows.Close() + + if !row.rows.Next() { + if err := row.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := row.rows.ScanSlice(dest) + if err != nil { + return err + } + + // Make sure the query can be processed to completion with no errors. + return row.rows.Close() +} + +// scan data to a map's pointer +func (row *Row) ScanMap(dest interface{}) error { + if row.err != nil { + return row.err + } + defer row.rows.Close() + + if !row.rows.Next() { + if err := row.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := row.rows.ScanMap(dest) + if err != nil { + return err + } + + // Make sure the query can be processed to completion with no errors. + return row.rows.Close() +} + +func (row *Row) ToMapString() (map[string]string, error) { + cols, err := row.Columns() + if err != nil { + return nil, err + } + + var record = make(map[string]string, len(cols)) + err = row.ScanMap(&record) + if err != nil { + return nil, err + } + + return record, nil +} diff --git a/core/scan.go b/core/scan.go new file mode 100644 index 0000000..897b534 --- /dev/null +++ b/core/scan.go @@ -0,0 +1,66 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "database/sql/driver" + "fmt" + "time" +) + +type NullTime time.Time + +var ( + _ driver.Valuer = NullTime{} +) + +func (ns *NullTime) Scan(value interface{}) error { + if value == nil { + return nil + } + return convertTime(ns, value) +} + +// Value implements the driver Valuer interface. +func (ns NullTime) Value() (driver.Value, error) { + if (time.Time)(ns).IsZero() { + return nil, nil + } + return (time.Time)(ns).Format("2006-01-02 15:04:05"), nil +} + +func convertTime(dest *NullTime, src interface{}) error { + // Common cases, without reflect. 
+ switch s := src.(type) { + case string: + t, err := time.Parse("2006-01-02 15:04:05", s) + if err != nil { + return err + } + *dest = NullTime(t) + return nil + case []uint8: + t, err := time.Parse("2006-01-02 15:04:05", string(s)) + if err != nil { + return err + } + *dest = NullTime(t) + return nil + case time.Time: + *dest = NullTime(s) + return nil + case nil: + default: + return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest) + } + return nil +} + +type EmptyScanner struct { +} + +func (EmptyScanner) Scan(src interface{}) error { + return nil +} diff --git a/core/stmt.go b/core/stmt.go new file mode 100644 index 0000000..3586a71 --- /dev/null +++ b/core/stmt.go @@ -0,0 +1,228 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "context" + "database/sql" + "errors" + "reflect" + "time" + + "github.com/xormplus/xorm/log" +) + +// Stmt represents a prepared statement, with support for ?name-style named placeholders
+type Stmt struct { + *sql.Stmt + db *DB + names map[string]int + query string +} + +func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { + names := make(map[string]int) + var i int + query = re.ReplaceAllStringFunc(query, func(src string) string { // rewrite ?name-style markers to ? and record each name's position + names[src[1:]] = i + i++ + return "?" + }) + + start := time.Now() + showSQL := db.NeedLogSQL(ctx) + if showSQL { + db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: "PREPARE", + }) + } + stmt, err := db.DB.PrepareContext(ctx, query) + if showSQL { + db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: "PREPARE", + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } + if err != nil { + return nil, err + } + + return &Stmt{stmt, db, names, query}, nil +} + +func (db *DB) Prepare(query string) (*Stmt, error) { + return db.PrepareContext(context.Background(), query) +} + +func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result, error) { + vv := reflect.ValueOf(mp) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return nil, errors.New("mp should be a map's pointer") + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface() + } + return s.ExecContext(ctx, args...) +} + +func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) { + return s.ExecMapContext(context.Background(), mp) +} + +func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Result, error) { + vv := reflect.ValueOf(st) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return nil, errors.New("st should be a struct's pointer") + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().FieldByName(k).Interface() + } + return s.ExecContext(ctx, args...)
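To make the named-placeholder flow concrete, here is a small sketch. The package-level re pattern is defined elsewhere in this package and is assumed to match ?name-style markers, as the ExecMap tests earlier suggest: PrepareContext rewrites every marker to a plain ? while recording its position in names, and ExecMapContext later looks each value up by key, so map order never matters. Note that a pointer to the map is required:

    // insertUser is a hypothetical helper; Prepare and ExecMap are the
    // APIs defined in this file.
    func insertUser(db *core.DB) error {
        stmt, err := db.Prepare(
            "insert into user (name, title) values (?name, ?title)")
        if err != nil {
            return err
        }
        defer stmt.Close()

        mp := map[string]interface{}{
            "name":  "xlw",
            "title": "tester",
        }
        // Values are bound by key via s.names, not by map iteration order.
        _, err = stmt.ExecMap(&mp)
        return err
    }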
+} + +func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) { + return s.ExecStructContext(context.Background(), st) +} + +func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { + start := time.Now() + showSQL := s.db.NeedLogSQL(ctx) + if showSQL { + s.db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: s.query, + Args: args, + }) + } + res, err := s.Stmt.ExecContext(ctx, args...) + if showSQL { + s.db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: s.query, + Args: args, + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } + return res, err +} + +func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { + start := time.Now() + showSQL := s.db.NeedLogSQL(ctx) + if showSQL { + s.db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: s.query, + Args: args, + }) + } + rows, err := s.Stmt.QueryContext(ctx, args...) + if showSQL { + s.db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: s.query, + Args: args, + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } + if err != nil { + return nil, err + } + return &Rows{rows, s.db}, nil +} + +func (s *Stmt) Query(args ...interface{}) (*Rows, error) { + return s.QueryContext(context.Background(), args...) +} + +func (s *Stmt) QueryMapContext(ctx context.Context, mp interface{}) (*Rows, error) { + vv := reflect.ValueOf(mp) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return nil, errors.New("mp should be a map's pointer") + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface() + } + + return s.QueryContext(ctx, args...) +} + +func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) { + return s.QueryMapContext(context.Background(), mp) +} + +func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, error) { + vv := reflect.ValueOf(st) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return nil, errors.New("st should be a struct's pointer") + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().FieldByName(k).Interface() + } + + return s.QueryContext(ctx, args...) +} + +func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) { + return s.QueryStructContext(context.Background(), st) +} + +func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row { + rows, err := s.QueryContext(ctx, args...) + return &Row{rows, err} +} + +func (s *Stmt) QueryRow(args ...interface{}) *Row { + return s.QueryRowContext(context.Background(), args...) +} + +func (s *Stmt) QueryRowMapContext(ctx context.Context, mp interface{}) *Row { + vv := reflect.ValueOf(mp) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return &Row{nil, errors.New("mp should be a map's pointer")} + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface() + } + + return s.QueryRowContext(ctx, args...)
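QueryRow and friends return a *Row that carries any deferred error, so calls can be chained without an intermediate check; the error (or sql.ErrNoRows) only surfaces at Scan. A sketch under the same assumed user table as above:

    func countUsers(db *core.DB) (int, error) {
        stmt, err := db.Prepare("select count(*) from user")
        if err != nil {
            return 0, err
        }
        defer stmt.Close()

        var n int
        // A query error is stored in the Row and returned here by Scan;
        // an empty result yields sql.ErrNoRows.
        err = stmt.QueryRow().Scan(&n)
        return n, err
    }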
+} + +func (s *Stmt) QueryRowMap(mp interface{}) *Row { + return s.QueryRowMapContext(context.Background(), mp) +} + +func (s *Stmt) QueryRowStructContext(ctx context.Context, st interface{}) *Row { + vv := reflect.ValueOf(st) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return &Row{nil, errors.New("st should be a struct's pointer")} + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().FieldByName(k).Interface() + } + + return s.QueryRowContext(ctx, args...) +} + +func (s *Stmt) QueryRowStruct(st interface{}) *Row { + return s.QueryRowStructContext(context.Background(), st) +} diff --git a/core/tx.go b/core/tx.go new file mode 100644 index 0000000..2a99f47 --- /dev/null +++ b/core/tx.go @@ -0,0 +1,233 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "context" + "database/sql" + "time" + + "github.com/xormplus/xorm/log" +) + +var ( + _ QueryExecuter = &Tx{} +) + +// Tx represents a transaction +type Tx struct { + *sql.Tx + db *DB +} + +func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { + start := time.Now() + showSQL := db.NeedLogSQL(ctx) + if showSQL { + db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: "BEGIN TRANSACTION", + }) + } + tx, err := db.DB.BeginTx(ctx, opts) + if showSQL { + db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: "BEGIN TRANSACTION", + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } + if err != nil { + return nil, err + } + return &Tx{tx, db}, nil +} + +func (db *DB) Begin() (*Tx, error) { + return db.BeginTx(context.Background(), nil) +} + +func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { + names := make(map[string]int) + var i int + query = re.ReplaceAllStringFunc(query, func(src string) string { + names[src[1:]] = i + i++ + return "?" + }) + + start := time.Now() + showSQL := tx.db.NeedLogSQL(ctx) + if showSQL { + tx.db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: "PREPARE", + }) + } + stmt, err := tx.Tx.PrepareContext(ctx, query) + if showSQL { + tx.db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: "PREPARE", + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } + if err != nil { + return nil, err + } + return &Stmt{stmt, tx.db, names, query}, nil +} + +func (tx *Tx) Prepare(query string) (*Stmt, error) { + return tx.PrepareContext(context.Background(), query) +} + +func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { + stmt.Stmt = tx.Tx.StmtContext(ctx, stmt.Stmt) + return stmt +} + +func (tx *Tx) Stmt(stmt *Stmt) *Stmt { + return tx.StmtContext(context.Background(), stmt) +} + +func (tx *Tx) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) { + query, args, err := MapToSlice(query, mp) + if err != nil { + return nil, err + } + return tx.ExecContext(ctx, query, args...) +} + +func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) { + return tx.ExecMapContext(context.Background(), query, mp) +} + +func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) { + query, args, err := StructToSlice(query, st) + if err != nil { + return nil, err + } + return tx.ExecContext(ctx, query, args...) 
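The Tx wrapper mirrors the DB API, so the same named markers are assumed to work inside a transaction; MapToSlice (referenced above, defined elsewhere in this package) expands them against the map before the SQL reaches the driver. A sketch with the same assumed user table:

    func renameUser(db *core.DB, from, to string) error {
        tx, err := db.Begin()
        if err != nil {
            return err
        }
        _, err = tx.ExecMap(
            "update user set name = ?to where name = ?from",
            &map[string]interface{}{"from": from, "to": to})
        if err != nil {
            tx.Rollback() // best effort; the original error wins
            return err
        }
        return tx.Commit()
    }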
+} + +func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + start := time.Now() + showSQL := tx.db.NeedLogSQL(ctx) + if showSQL { + tx.db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + }) + } + res, err := tx.Tx.ExecContext(ctx, query, args...) + if showSQL { + tx.db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } + return res, err +} + +func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) { + return tx.ExecStructContext(context.Background(), query, st) +} + +func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + start := time.Now() + showSQL := tx.db.NeedLogSQL(ctx) + if showSQL { + tx.db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + }) + } + rows, err := tx.Tx.QueryContext(ctx, query, args...) + if showSQL { + tx.db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } + if err != nil { + if rows != nil { + rows.Close() + } + return nil, err + } + return &Rows{rows, tx.db}, nil +} + +func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { + return tx.QueryContext(context.Background(), query, args...) +} + +func (tx *Tx) QueryMapContext(ctx context.Context, query string, mp interface{}) (*Rows, error) { + query, args, err := MapToSlice(query, mp) + if err != nil { + return nil, err + } + return tx.QueryContext(ctx, query, args...) +} + +func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) { + return tx.QueryMapContext(context.Background(), query, mp) +} + +func (tx *Tx) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) { + query, args, err := StructToSlice(query, st) + if err != nil { + return nil, err + } + return tx.QueryContext(ctx, query, args...) +} + +func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) { + return tx.QueryStructContext(context.Background(), query, st) +} + +func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := tx.QueryContext(ctx, query, args...) + return &Row{rows, err} +} + +func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { + return tx.QueryRowContext(context.Background(), query, args...) +} + +func (tx *Tx) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row { + query, args, err := MapToSlice(query, mp) + if err != nil { + return &Row{nil, err} + } + return tx.QueryRowContext(ctx, query, args...) +} + +func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row { + return tx.QueryRowMapContext(context.Background(), query, mp) +} + +func (tx *Tx) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row { + query, args, err := StructToSlice(query, st) + if err != nil { + return &Row{nil, err} + } + return tx.QueryRowContext(ctx, query, args...) +} + +func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row { + return tx.QueryRowStructContext(context.Background(), query, st) +} diff --git a/dialects/dialect.go b/dialects/dialect.go new file mode 100644 index 0000000..91c6479 --- /dev/null +++ b/dialects/dialect.go @@ -0,0 +1,284 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package dialects + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/xormplus/xorm/core" + "github.com/xormplus/xorm/schemas" +) + +// URI represents an uri to visit database +type URI struct { + DBType schemas.DBType + Proto string + Host string + Port string + DBName string + User string + Passwd string + Charset string + Laddr string + Raddr string + Timeout time.Duration + Schema string +} + +// SetSchema set schema +func (uri *URI) SetSchema(schema string) { + // hack me + if uri.DBType == schemas.POSTGRES { + uri.Schema = strings.TrimSpace(schema) + } +} + +// Dialect represents a kind of database +type Dialect interface { + Init(*URI) error + URI() *URI + SQLType(*schemas.Column) string + FormatBytes(b []byte) string + + IsReserved(string) bool + Quoter() schemas.Quoter + SetQuotePolicy(quotePolicy QuotePolicy) + + AutoIncrStr() string + + GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) + IndexCheckSQL(tableName, idxName string) (string, []interface{}) + CreateIndexSQL(tableName string, index *schemas.Index) string + DropIndexSQL(tableName string, index *schemas.Index) string + + GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) + IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) + CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) + DropTableSQL(tableName string) (string, bool) + + GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) + IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error) + AddColumnSQL(tableName string, col *schemas.Column) string + ModifyColumnSQL(tableName string, col *schemas.Column) string + + ForUpdateSQL(query string) string + + Filters() []Filter + SetParams(params map[string]string) +} + +// Base represents a basic dialect and all real dialects could embed this struct +type Base struct { + dialect Dialect + uri *URI + quoter schemas.Quoter +} + +func (b *Base) Quoter() schemas.Quoter { + return b.quoter +} + +func (b *Base) Init(dialect Dialect, uri *URI) error { + b.dialect, b.uri = dialect, uri + return nil +} + +func (b *Base) URI() *URI { + return b.uri +} + +func (b *Base) DBType() schemas.DBType { + return b.uri.DBType +} + +func (b *Base) FormatBytes(bs []byte) string { + return fmt.Sprintf("0x%x", bs) +} + +func (db *Base) DropTableSQL(tableName string) (string, bool) { + quote := db.dialect.Quoter().Quote + return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)), true +} + +func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query string, args ...interface{}) (bool, error) { + rows, err := queryer.QueryContext(ctx, query, args...) + if err != nil { + return false, err + } + defer rows.Close() + + if rows.Next() { + return true, nil + } + return false, nil +} + +func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { + quote := db.dialect.Quoter().Quote + query := fmt.Sprintf( + "SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? 
AND %v = ?", + quote("COLUMN_NAME"), + quote("INFORMATION_SCHEMA"), + quote("COLUMNS"), + quote("TABLE_SCHEMA"), + quote("TABLE_NAME"), + quote("COLUMN_NAME"), + ) + return db.HasRecords(queryer, ctx, query, db.uri.DBName, tableName, colName) +} + +func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string { + s, _ := ColumnString(db.dialect, col, true) + return fmt.Sprintf("ALTER TABLE %v ADD %v", db.dialect.Quoter().Quote(tableName), s) +} + +func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string { + quoter := db.dialect.Quoter() + var unique string + var idxName string + if index.Type == schemas.UniqueType { + unique = " UNIQUE" + } + idxName = index.XName(tableName) + return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique, + quoter.Quote(idxName), quoter.Quote(tableName), + quoter.Join(index.Cols, ",")) +} + +func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string { + quote := db.dialect.Quoter().Quote + var name string + if index.IsRegular { + name = index.XName(tableName) + } else { + name = index.Name + } + return fmt.Sprintf("DROP INDEX %v ON %s", quote(name), quote(tableName)) +} + +func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string { + s, _ := ColumnString(db.dialect, col, false) + return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, s) +} + +func (b *Base) ForUpdateSQL(query string) string { + return query + " FOR UPDATE" +} + +func (b *Base) SetParams(params map[string]string) { +} + +var ( + dialects = map[string]func() Dialect{} +) + +// RegisterDialect register database dialect +func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) { + if dialectFunc == nil { + panic("core: Register dialect is nil") + } + dialects[strings.ToLower(string(dbName))] = dialectFunc // !nashtsai! allow override dialect +} + +// QueryDialect query if registered database dialect +func QueryDialect(dbName schemas.DBType) Dialect { + if d, ok := dialects[strings.ToLower(string(dbName))]; ok { + return d() + } + return nil +} + +func regDrvsNDialects() bool { + providedDrvsNDialects := map[string]struct { + dbType schemas.DBType + getDriver func() Driver + getDialect func() Dialect + }{ + "mssql": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }}, + "odbc": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }}, // !nashtsai! 
TODO change this when supporting MS Access + "mysql": {"mysql", func() Driver { return &mysqlDriver{} }, func() Dialect { return &mysql{} }}, + "mymysql": {"mysql", func() Driver { return &mymysqlDriver{} }, func() Dialect { return &mysql{} }}, + "postgres": {"postgres", func() Driver { return &pqDriver{} }, func() Dialect { return &postgres{} }}, + "pgx": {"postgres", func() Driver { return &pqDriverPgx{} }, func() Dialect { return &postgres{} }}, + "sqlite3": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }}, + "oci8": {"oracle", func() Driver { return &oci8Driver{} }, func() Dialect { return &oracle{} }}, + "goracle": {"oracle", func() Driver { return &goracleDriver{} }, func() Dialect { return &oracle{} }}, + } + + for driverName, v := range providedDrvsNDialects { + if driver := QueryDriver(driverName); driver == nil { + RegisterDriver(driverName, v.getDriver()) + RegisterDialect(v.dbType, v.getDialect) + } + } + return true +} + +func init() { + regDrvsNDialects() +} + +// ColumnString generate column description string according dialect +func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) (string, error) { + bd := strings.Builder{} + + if err := dialect.Quoter().QuoteTo(&bd, col.Name); err != nil { + return "", err + } + + if err := bd.WriteByte(' '); err != nil { + return "", err + } + + if _, err := bd.WriteString(dialect.SQLType(col)); err != nil { + return "", err + } + + if err := bd.WriteByte(' '); err != nil { + return "", err + } + + if includePrimaryKey && col.IsPrimaryKey { + if _, err := bd.WriteString("PRIMARY KEY "); err != nil { + return "", err + } + + if col.IsAutoIncrement { + if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil { + return "", err + } + if err := bd.WriteByte(' '); err != nil { + return "", err + } + } + } + + if col.Default != "" { + if _, err := bd.WriteString("DEFAULT "); err != nil { + return "", err + } + if _, err := bd.WriteString(col.Default); err != nil { + return "", err + } + if err := bd.WriteByte(' '); err != nil { + return "", err + } + } + + if col.Nullable { + if _, err := bd.WriteString("NULL "); err != nil { + return "", err + } + } else { + if _, err := bd.WriteString("NOT NULL "); err != nil { + return "", err + } + } + + return bd.String(), nil +} diff --git a/dialects/driver.go b/dialects/driver.go new file mode 100644 index 0000000..ae3afe4 --- /dev/null +++ b/dialects/driver.go @@ -0,0 +1,57 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package dialects + +import ( + "fmt" +) + +type Driver interface { + Parse(string, string) (*URI, error) +} + +var ( + drivers = map[string]Driver{} +) + +func RegisterDriver(driverName string, driver Driver) { + if driver == nil { + panic("core: Register driver is nil") + } + if _, dup := drivers[driverName]; dup { + panic("core: Register called twice for driver " + driverName) + } + drivers[driverName] = driver +} + +func QueryDriver(driverName string) Driver { + return drivers[driverName] +} + +func RegisteredDriverSize() int { + return len(drivers) +} + +// OpenDialect opens a dialect via driver name and connection string
+func OpenDialect(driverName, connstr string) (Dialect, error) { + driver := QueryDriver(driverName) + if driver == nil { + return nil, fmt.Errorf("Unsupported driver name: %v", driverName) + } + + uri, err := driver.Parse(driverName, connstr) + if err != nil { + return nil, err + } + + dialect := QueryDialect(uri.DBType) + if dialect == nil { + return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType) + } + + if err = dialect.Init(uri); err != nil { + return nil, err + } + + return dialect, nil +} diff --git a/dialects/filter.go b/dialects/filter.go new file mode 100644 index 0000000..6968b6c --- /dev/null +++ b/dialects/filter.go @@ -0,0 +1,43 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dialects + +import ( + "fmt" + "strings" +) + +// Filter is an interface to filter SQL
+type Filter interface { + Do(sql string) string +} + +// SeqFilter replaces the ?, ? ... placeholders in a SQL string with $1, $2 ...
+type SeqFilter struct { + Prefix string + Start int +} + +func convertQuestionMark(sql, prefix string, start int) string { + var buf strings.Builder + var beginSingleQuote bool + var index = start + for _, c := range sql { + if !beginSingleQuote && c == '?' { + buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) + index++ + } else { + if c == '\'' { + beginSingleQuote = !beginSingleQuote + } + buf.WriteRune(c) + } + } + return buf.String() +} + +func (s *SeqFilter) Do(sql string) string { + return convertQuestionMark(sql, s.Prefix, s.Start) +} diff --git a/dialects/filter_test.go b/dialects/filter_test.go new file mode 100644 index 0000000..7e2ef0a --- /dev/null +++ b/dialects/filter_test.go @@ -0,0 +1,21 @@ +package dialects + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSeqFilter(t *testing.T) { + var kases = map[string]string{ + "SELECT * FROM TABLE1 WHERE a=? AND b=?": "SELECT * FROM TABLE1 WHERE a=$1 AND b=$2", + "SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=? AND b=?": "SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=$1 AND b=$2", + "select '1''?' from issue": "select '1''?' from issue", + "select '1\\??' from issue": "select '1\\??' from issue", + "select '1\\\\',? from issue": "select '1\\\\',$1 from issue", + "select '1\\''?',? from issue": "select '1\\''?',$1 from issue", + } + for sql, result := range kases { + assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) + } +} diff --git a/dialects/gen_reserved.sh b/dialects/gen_reserved.sh new file mode 100755 index 0000000..434a1bf --- /dev/null +++ b/dialects/gen_reserved.sh @@ -0,0 +1,6 @@ +#!/bin/bash +if [ -f "$1" ]; then + cat "$1" | awk '{printf("\""$1"\":true,\n")}' +else + echo "argument $1 is not a file!"
+fi diff --git a/dialect_mssql.go b/dialects/mssql.go similarity index 76% rename from dialect_mssql.go rename to dialects/mssql.go index f73a0c4..aff3225 100644 --- a/dialect_mssql.go +++ b/dialects/mssql.go @@ -2,16 +2,18 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package dialects import ( + "context" "errors" "fmt" "net/url" "strconv" "strings" - "github.com/xormplus/core" + "github.com/xormplus/xorm/core" + "github.com/xormplus/xorm/schemas" ) var ( @@ -202,67 +204,74 @@ var ( "EXIT": true, "PROC": true, } + + mssqlQuoter = schemas.Quoter{ + Prefix: '[', + Suffix: ']', + IsReserved: schemas.AlwaysReserve, + } ) type mssql struct { - core.Base + Base } -func (db *mssql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *mssql) Init(uri *URI) error { + db.quoter = mssqlQuoter + return db.Base.Init(db, uri) } -func (db *mssql) SqlType(c *core.Column) string { +func (db *mssql) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { - case core.Bool: - res = core.Bit + case schemas.Bool: + res = schemas.Bit if strings.EqualFold(c.Default, "true") { c.Default = "1" } else if strings.EqualFold(c.Default, "false") { c.Default = "0" } - case core.Serial: + case schemas.Serial: c.IsAutoIncrement = true c.IsPrimaryKey = true c.Nullable = false - res = core.Int - case core.BigSerial: + res = schemas.Int + case schemas.BigSerial: c.IsAutoIncrement = true c.IsPrimaryKey = true c.Nullable = false - res = core.BigInt - case core.Bytea, core.Blob, core.Binary, core.TinyBlob, core.MediumBlob, core.LongBlob: - res = core.VarBinary + res = schemas.BigInt + case schemas.Bytea, schemas.Blob, schemas.Binary, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob: + res = schemas.VarBinary if c.Length == 0 { c.Length = 50 } - case core.TimeStamp: - res = core.DateTime - case core.TimeStampz: + case schemas.TimeStamp: + res = schemas.DateTime + case schemas.TimeStampz: res = "DATETIMEOFFSET" c.Length = 7 - case core.MediumInt: - res = core.Int - case core.Text, core.MediumText, core.TinyText, core.LongText, core.Json: - res = core.Varchar + "(MAX)" - case core.Double: - res = core.Real - case core.Uuid: - res = core.Varchar + case schemas.MediumInt: + res = schemas.Int + case schemas.Text, schemas.MediumText, schemas.TinyText, schemas.LongText, schemas.Json: + res = schemas.Varchar + "(MAX)" + case schemas.Double: + res = schemas.Real + case schemas.Uuid: + res = schemas.Varchar c.Length = 40 - case core.TinyInt: - res = core.TinyInt + case schemas.TinyInt: + res = schemas.TinyInt c.Length = 0 - case core.BigInt: - res = core.BigInt + case schemas.BigInt: + res = schemas.BigInt c.Length = 0 default: res = t } - if res == core.Int { - return core.Int + if res == schemas.Int { + return schemas.Int } hasLen1 := (c.Length > 0) @@ -276,88 +285,78 @@ func (db *mssql) SqlType(c *core.Column) string { return res } -func (db *mssql) SupportInsertMany() bool { - return true -} - func (db *mssql) IsReserved(name string) bool { - _, ok := mssqlReservedWords[name] + _, ok := mssqlReservedWords[strings.ToUpper(name)] return ok } -func (db *mssql) Quote(name string) string { - return "[" + name + "]" -} - -func (db *mssql) SupportEngine() bool { - return false +func (db *mssql) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = mssqlQuoter + q.IsReserved = 
schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + var q = mssqlQuoter + q.IsReserved = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = mssqlQuoter + } } func (db *mssql) AutoIncrStr() string { return "IDENTITY" } -func (db *mssql) DropTableSql(tableName string) string { +func (db *mssql) DropTableSQL(tableName string) (string, bool) { return fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+ "object_id(N'%s') and OBJECTPROPERTY(id, N'IsUserTable') = 1) "+ - "DROP TABLE \"%s\"", tableName, tableName) + "DROP TABLE \"%s\"", tableName, tableName), true } -func (db *mssql) SupportCharset() bool { - return false -} - -func (db *mssql) IndexOnTable() bool { - return true -} - -func (db *mssql) IndexCheckSql(tableName, idxName string) (string, []interface{}) { +func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { args := []interface{}{idxName} sql := "select name from sysindexes where id=object_id('" + tableName + "') and name=?" return sql, args } -/*func (db *mssql) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{tableName, colName} - sql := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?` - return sql, args -}*/ - -func (db *mssql) IsColumnExist(tableName, colName string) (bool, error) { +func (db *mssql) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { query := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?` - return db.HasRecords(query, tableName, colName) + return db.HasRecords(queryer, ctx, query, tableName, colName) } -func (db *mssql) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{} +func (db *mssql) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { sql := "select * from sysobjects where id = object_id(N'" + tableName + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1" - return sql, args + return db.HasRecords(queryer, ctx, sql) } -func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { +func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{} s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable, "default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END), replace(replace(isnull(c.text,''),'(',''),')','') as vdefault, - ISNULL(i.is_primary_key, 0), a.is_identity as is_identity + ISNULL(p.is_primary_key, 0), a.is_identity as is_identity from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id left join sys.syscomments c on a.default_object_id=c.id - LEFT OUTER JOIN - sys.index_columns ic ON ic.object_id = a.object_id AND ic.column_id = a.column_id - LEFT OUTER JOIN - sys.indexes i ON ic.object_id = i.object_id AND ic.index_id = i.index_id + LEFT OUTER JOIN (SELECT i.object_id, ic.column_id, i.is_primary_key + FROM sys.indexes i + LEFT JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id + WHERE i.is_primary_key = 1 + ) as p on p.object_id = a.object_id AND p.column_id = a.column_id where a.object_id=object_id('` + tableName + `')` - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) 
if err != nil { return nil, nil, err } defer rows.Close() - cols := make(map[string]*core.Column) + cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { var name, ctype, vdefault string @@ -368,7 +367,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column return nil, nil, err } - col := new(core.Column) + col := new(schemas.Column) col.Indexes = make(map[string]int) col.Name = strings.Trim(name, "` ") col.Nullable = nullable @@ -387,14 +386,14 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column } switch ct { case "DATETIMEOFFSET": - col.SQLType = core.SQLType{Name: core.TimeStampz, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0} case "NVARCHAR": - col.SQLType = core.SQLType{Name: core.NVarchar, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.NVarchar, DefaultLength: 0, DefaultLength2: 0} case "IMAGE": - col.SQLType = core.SQLType{Name: core.VarBinary, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.VarBinary, DefaultLength: 0, DefaultLength2: 0} default: - if _, ok := core.SqlTypes[ct]; ok { - col.SQLType = core.SQLType{Name: ct, DefaultLength: 0, DefaultLength2: 0} + if _, ok := schemas.SqlTypes[ct]; ok { + col.SQLType = schemas.SQLType{Name: ct, DefaultLength: 0, DefaultLength2: 0} } else { return nil, nil, fmt.Errorf("Unknown colType %v for %v - %v", ct, tableName, col.Name) } @@ -406,20 +405,19 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column return colSeq, cols, nil } -func (db *mssql) GetTables() ([]*core.Table, error) { +func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := `select name from sysobjects where xtype ='U'` - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } defer rows.Close() - tables := make([]*core.Table, 0) + tables := make([]*schemas.Table, 0) for rows.Next() { - table := core.NewEmptyTable() + table := schemas.NewEmptyTable() var name string err = rows.Scan(&name) if err != nil { @@ -431,7 +429,7 @@ func (db *mssql) GetTables() ([]*core.Table, error) { return tables, nil } -func (db *mssql) GetIndexes(tableName string) (map[string]*core.Index, error) { +func (db *mssql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := `SELECT IXS.NAME AS [INDEX_NAME], @@ -444,15 +442,14 @@ INNER JOIN SYS.COLUMNS C ON IXS.OBJECT_ID=C.OBJECT_ID AND IXCS.COLUMN_ID=C.COLUMN_ID WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? ` - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } defer rows.Close() - indexes := make(map[string]*core.Index, 0) + indexes := make(map[string]*schemas.Index, 0) for rows.Next() { var indexType int var indexName, colName, isUnique string @@ -468,9 +465,9 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? } if i { - indexType = core.UniqueType + indexType = schemas.UniqueType } else { - indexType = core.IndexType + indexType = schemas.IndexType } colName = strings.Trim(colName, "` ") @@ -480,10 +477,10 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? 
isRegular = true } - var index *core.Index + var index *schemas.Index var ok bool if index, ok = indexes[indexName]; !ok { - index = new(core.Index) + index = new(schemas.Index) index.Type = indexType index.Name = indexName index.IsRegular = isRegular @@ -494,7 +491,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? return indexes, nil } -func (db *mssql) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string { +func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { var sql string if tableName == "" { tableName = table.Name @@ -502,17 +499,14 @@ func (db *mssql) CreateTableSql(table *core.Table, tableName, storeEngine, chars sql = "IF NOT EXISTS (SELECT [name] FROM sys.tables WHERE [name] = '" + tableName + "' ) CREATE TABLE " - sql += db.Quote(tableName) + " (" + sql += db.Quoter().Quote(tableName) + " (" pkList := table.PrimaryKeys for _, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - if col.IsPrimaryKey && len(pkList) == 1 { - sql += col.String(db) - } else { - sql += col.StringNoPk(db) - } + s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) + sql += s sql = strings.TrimSpace(sql) sql += ", " } @@ -525,34 +519,21 @@ func (db *mssql) CreateTableSql(table *core.Table, tableName, storeEngine, chars sql = sql[:len(sql)-2] + ")" sql += ";" - return sql + return []string{sql}, true } -func (db *mssql) ForUpdateSql(query string) string { +func (db *mssql) ForUpdateSQL(query string) string { return query } -func (db *mssql) CreateIndexSql(tableName string, index *core.Index) string { - quote := db.Quote - var unique string - var idxName string - if index.Type == core.UniqueType { - unique = " UNIQUE" - } - idxName = index.XName(tableName) - return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique, - quote(idxName), quote(tableName), - quote(strings.Join(index.Cols, quote(",")))) -} - -func (db *mssql) Filters() []core.Filter { - return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}} +func (db *mssql) Filters() []Filter { + return []Filter{} } type odbcDriver struct { } -func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { +func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) { var dbName string if strings.HasPrefix(dataSourceName, "sqlserver://") { @@ -576,5 +557,5 @@ func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) if dbName == "" { return nil, errors.New("no db name provided") } - return &core.Uri{DbName: dbName, DbType: core.MSSQL}, nil + return &URI{DBName: dbName, DBType: schemas.MSSQL}, nil } diff --git a/dialect_mssql_test.go b/dialects/mssql_test.go similarity index 83% rename from dialect_mssql_test.go rename to dialects/mssql_test.go index c29712d..168f177 100644 --- a/dialect_mssql_test.go +++ b/dialects/mssql_test.go @@ -2,13 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-package xorm +package dialects import ( "reflect" "testing" - - "github.com/xormplus/core" ) func TestParseMSSQL(t *testing.T) { @@ -21,15 +19,15 @@ func TestParseMSSQL(t *testing.T) { {"server=localhost;user id=sa;password=yourStrong(!)Password;database=db", "db", true}, } - driver := core.QueryDriver("mssql") + driver := QueryDriver("mssql") for _, test := range tests { uri, err := driver.Parse("mssql", test.in) if err != nil && test.valid { t.Errorf("%q got unexpected error: %s", test.in, err) - } else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) { - t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected) + } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { + t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) } } } diff --git a/dialect_mysql.go b/dialects/mysql.go similarity index 76% rename from dialect_mysql.go rename to dialects/mysql.go index 79a67fb..9f54edf 100644 --- a/dialect_mysql.go +++ b/dialects/mysql.go @@ -2,9 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package dialects import ( + "context" "crypto/tls" "errors" "fmt" @@ -13,7 +14,8 @@ import ( "strings" "time" - "github.com/xormplus/core" + "github.com/xormplus/xorm/core" + "github.com/xormplus/xorm/schemas" ) var ( @@ -159,10 +161,16 @@ var ( "YEAR_MONTH": true, "ZEROFILL": true, } + + mysqlQuoter = schemas.Quoter{ + Prefix: '`', + Suffix: '`', + IsReserved: schemas.AlwaysReserve, + } ) type mysql struct { - core.Base + Base net string addr string params map[string]string @@ -175,8 +183,9 @@ type mysql struct { rowFormat string } -func (db *mysql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *mysql) Init(uri *URI) error { + db.quoter = mysqlQuoter + return db.Base.Init(db, uri) } func (db *mysql) SetParams(params map[string]string) { @@ -199,29 +208,29 @@ func (db *mysql) SetParams(params map[string]string) { } } -func (db *mysql) SqlType(c *core.Column) string { +func (db *mysql) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { - case core.Bool: - res = core.TinyInt + case schemas.Bool: + res = schemas.TinyInt c.Length = 1 - case core.Serial: + case schemas.Serial: c.IsAutoIncrement = true c.IsPrimaryKey = true c.Nullable = false - res = core.Int - case core.BigSerial: + res = schemas.Int + case schemas.BigSerial: c.IsAutoIncrement = true c.IsPrimaryKey = true c.Nullable = false - res = core.BigInt - case core.Bytea: - res = core.Blob - case core.TimeStampz: - res = core.Char + res = schemas.BigInt + case schemas.Bytea: + res = schemas.Blob + case schemas.TimeStampz: + res = schemas.Char c.Length = 64 - case core.Enum: // mysql enum - res = core.Enum + case schemas.Enum: // mysql enum + res = schemas.Enum res += "(" opts := "" for v := range c.EnumOptions { @@ -229,8 +238,8 @@ func (db *mysql) SqlType(c *core.Column) string { } res += strings.TrimLeft(opts, ",") res += ")" - case core.Set: // mysql set - res = core.Set + case schemas.Set: // mysql set + res = schemas.Set res += "(" opts := "" for v := range c.SetOptions { @@ -238,13 +247,13 @@ func (db *mysql) SqlType(c *core.Column) string { } res += strings.TrimLeft(opts, ",") res += ")" - case core.NVarchar: - res = core.Varchar - case core.Uuid: - res = core.Varchar + case schemas.NVarchar: + res = schemas.Varchar + case schemas.Uuid: + res = schemas.Varchar c.Length = 40 - case 
core.Json: - res = core.Text + case schemas.Json: + res = schemas.Text default: res = t } @@ -252,7 +261,7 @@ func (db *mysql) SqlType(c *core.Column) string { hasLen1 := (c.Length > 0) hasLen2 := (c.Length2 > 0) - if res == core.BigInt && !hasLen1 && !hasLen2 { + if res == schemas.BigInt && !hasLen1 && !hasLen2 { c.Length = 20 hasLen1 = true } @@ -265,70 +274,52 @@ func (db *mysql) SqlType(c *core.Column) string { return res } -func (db *mysql) SupportInsertMany() bool { - return true -} - func (db *mysql) IsReserved(name string) bool { - _, ok := mysqlReservedWords[name] + _, ok := mysqlReservedWords[strings.ToUpper(name)] return ok } -func (db *mysql) Quote(name string) string { - return "`" + name + "`" -} - -func (db *mysql) SupportEngine() bool { - return true -} - func (db *mysql) AutoIncrStr() string { return "AUTO_INCREMENT" } -func (db *mysql) SupportCharset() bool { - return true -} - -func (db *mysql) IndexOnTable() bool { - return true -} - -func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - args := []interface{}{db.DbName, tableName, idxName} +func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { + args := []interface{}{db.uri.DBName, tableName, idxName} sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`" sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?" return sql, args } -/*func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{db.DbName, tableName, colName} - sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" - return sql, args -}*/ - -func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{db.DbName, tableName} +func (db *mysql) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" - return sql, args + return db.HasRecords(queryer, ctx, sql, db.uri.DBName, tableName) } -func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { - args := []interface{}{db.DbName, tableName} +func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { + quoter := db.dialect.Quoter() + s, _ := ColumnString(db, col, true) + sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName), s) + if len(col.Comment) > 0 { + sql += " COMMENT '" + col.Comment + "'" + } + return sql +} + +func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { + args := []interface{}{db.uri.DBName, tableName} s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + " `COLUMN_KEY`, `EXTRA`,`COLUMN_COMMENT` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) 
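A note on the pattern above: dialects no longer hold their own *core.DB; every schema query now receives a core.Queryer plus a context.Context from the caller. IsTableExist and the other existence checks lean on a small Base.HasRecords helper whose body is not part of this hunk, so the following is a sketch of its assumed behavior, not the actual implementation:

package dialects

import (
	"context"

	"github.com/xormplus/xorm/core"
)

// HasRecords (assumed): run the probe query through the supplied Queryer and
// report whether it returned at least one row.
func (b *Base) HasRecords(queryer core.Queryer, ctx context.Context, query string, args ...interface{}) (bool, error) {
	rows, err := queryer.QueryContext(ctx, query, args...)
	if err != nil {
		return false, err
	}
	defer rows.Close()
	return rows.Next(), nil
}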
if err != nil { return nil, nil, err } defer rows.Close() - cols := make(map[string]*core.Column) + cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - col := new(core.Column) + col := new(schemas.Column) col.Indexes = make(map[string]int) var columnName, isNullable, colType, colKey, extra, comment string @@ -356,7 +347,7 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column var len1, len2 int if len(cts) == 2 { idx := strings.Index(cts[1], ")") - if colType == core.Enum && cts[1][0] == '\'' { // enum + if colType == schemas.Enum && cts[1][0] == '\'' { // enum options := strings.Split(cts[1][0:idx], ",") col.EnumOptions = make(map[string]int) for k, v := range options { @@ -364,7 +355,7 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column v = strings.Trim(v, "'") col.EnumOptions[v] = k } - } else if colType == core.Set && cts[1][0] == '\'' { + } else if colType == schemas.Set && cts[1][0] == '\'' { options := strings.Split(cts[1][0:idx], ",") col.SetOptions = make(map[string]int) for k, v := range options { @@ -394,8 +385,8 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column } col.Length = len1 col.Length2 = len2 - if _, ok := core.SqlTypes[colType]; ok { - col.SQLType = core.SQLType{Name: colType, DefaultLength: len1, DefaultLength2: len2} + if _, ok := schemas.SqlTypes[colType]; ok { + col.SQLType = schemas.SQLType{Name: colType, DefaultLength: len1, DefaultLength2: len2} } else { return nil, nil, fmt.Errorf("Unknown colType %v", colType) } @@ -424,48 +415,65 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column return colSeq, cols, nil } -func (db *mysql) GetTables() ([]*core.Table, error) { - args := []interface{}{db.DbName} - s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " + +func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { + args := []interface{}{db.uri.DBName} + s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " + "`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')" - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) 
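Quoting is likewise now data rather than behavior: each dialect carries a schemas.Quoter value (prefix byte, suffix byte, and an IsReserved predicate) instead of overriding a Quote method. Going only by the call sites in this patch (quoter.Quote, quoter.Join), a minimal usage sketch; the printed output is assumed from the Prefix/Suffix fields:

package main

import (
	"fmt"

	"github.com/xormplus/xorm/schemas"
)

func main() {
	// Same shape as the mysqlQuoter declared earlier in this file.
	q := schemas.Quoter{Prefix: '`', Suffix: '`', IsReserved: schemas.AlwaysReserve}
	fmt.Println(q.Quote("user"))                     // expected: `user`
	fmt.Println(q.Join([]string{"id", "name"}, ",")) // expected: `id`,`name`
}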
if err != nil { return nil, err } defer rows.Close() - tables := make([]*core.Table, 0) + tables := make([]*schemas.Table, 0) for rows.Next() { - table := core.NewEmptyTable() - var name, engine, tableRows, comment string - var autoIncr *string - err = rows.Scan(&name, &engine, &tableRows, &autoIncr, &comment) + table := schemas.NewEmptyTable() + var name, engine string + var autoIncr, comment *string + err = rows.Scan(&name, &engine, &autoIncr, &comment) if err != nil { return nil, err } table.Name = name - table.Comment = comment + if comment != nil { + table.Comment = *comment + } table.StoreEngine = engine tables = append(tables, table) } return tables, nil } -func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) { - args := []interface{}{db.DbName, tableName} +func (db *mysql) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = mysqlQuoter + q.IsReserved = schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + var q = mysqlQuoter + q.IsReserved = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = mysqlQuoter + } +} + +func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { + args := []interface{}{db.uri.DBName, tableName} s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } defer rows.Close() - indexes := make(map[string]*core.Index, 0) + indexes := make(map[string]*schemas.Index, 0) for rows.Next() { var indexType int var indexName, colName, nonUnique string @@ -479,9 +487,9 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) { } if "YES" == nonUnique || nonUnique == "1" { - indexType = core.IndexType + indexType = schemas.IndexType } else { - indexType = core.UniqueType + indexType = schemas.UniqueType } colName = strings.Trim(colName, "` ") @@ -491,10 +499,10 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) { isRegular = true } - var index *core.Index + var index *schemas.Index var ok bool if index, ok = indexes[indexName]; !ok { - index = new(core.Index) + index = new(schemas.Index) index.IsRegular = isRegular index.Type = indexType index.Name = indexName @@ -505,28 +513,15 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) { return indexes, nil } -func (db *mysql) CreateIndexSql(tableName string, index *core.Index) string { - quote := db.Quote - var unique string - var idxName string - if index.Type == core.UniqueType { - unique = " UNIQUE" - } - idxName = index.XName(tableName) - return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique, - quote(idxName), quote(tableName), - quote(strings.Join(index.Cols, quote(",")))) -} - -func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string { +func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { var sql = "CREATE TABLE IF NOT EXISTS " if tableName == "" { tableName = table.Name } - quotes := db.Quote("") + quoter := db.Quoter() - sql += db.Quote(tableName) + sql += quoter.Quote(tableName) sql += " (" if len(table.ColumnsSeq()) > 0 { @@ -534,11 +529,8 @@ func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, chars for _, 
colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - if col.IsPrimaryKey && len(pkList) == 1 { - sql += col.String(db) - } else { - sql += col.StringNoPk(db) - } + s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) + sql += s sql = strings.TrimSpace(sql) if len(col.Comment) > 0 { sql += " COMMENT '" + col.Comment + "'" @@ -548,7 +540,7 @@ func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, chars if len(pkList) > 1 { sql += "PRIMARY KEY ( " - sql += db.Quote(strings.Join(pkList, fmt.Sprintf("%c,%c", quotes[1], quotes[0]))) + sql += quoter.Join(pkList, ",") sql += " ), " } @@ -556,10 +548,11 @@ func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, chars } sql += ")" - if storeEngine != "" { - sql += " ENGINE=" + storeEngine + if table.StoreEngine != "" { + sql += " ENGINE=" + table.StoreEngine } + var charset = table.Charset if len(charset) == 0 { charset = db.URI().Charset } @@ -570,18 +563,18 @@ func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, chars if db.rowFormat != "" { sql += " ROW_FORMAT=" + db.rowFormat } - return sql + return []string{sql}, true } -func (db *mysql) Filters() []core.Filter { - return []core.Filter{&core.IdFilter{}} +func (db *mysql) Filters() []Filter { + return []Filter{} } type mymysqlDriver struct { } -func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { - db := &core.Uri{DbType: core.MYSQL} +func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { + uri := &URI{DBType: schemas.MYSQL} pd := strings.SplitN(dataSourceName, "*", 2) if len(pd) == 2 { @@ -590,9 +583,9 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, err if len(p) != 2 { return nil, errors.New("Wrong protocol part of URI") } - db.Proto = p[0] + uri.Proto = p[0] options := strings.Split(p[1], ",") - db.Raddr = options[0] + uri.Raddr = options[0] for _, o := range options[1:] { kv := strings.SplitN(o, "=", 2) var k, v string @@ -603,13 +596,13 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, err } switch k { case "laddr": - db.Laddr = v + uri.Laddr = v case "timeout": to, err := time.ParseDuration(v) if err != nil { return nil, err } - db.Timeout = to + uri.Timeout = to default: return nil, errors.New("Unknown option: " + k) } @@ -622,17 +615,17 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, err if len(dup) != 3 { return nil, errors.New("Wrong database part of URI") } - db.DbName = dup[0] - db.User = dup[1] - db.Passwd = dup[2] + uri.DBName = dup[0] + uri.User = dup[1] + uri.Passwd = dup[2] - return db, nil + return uri, nil } type mysqlDriver struct { } -func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { +func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { dsnPattern := regexp.MustCompile( `^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@] `(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]] @@ -642,12 +635,12 @@ func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error // tlsConfigRegister := make(map[string]*tls.Config) names := dsnPattern.SubexpNames() - uri := &core.Uri{DbType: core.MYSQL} + uri := &URI{DBType: schemas.MYSQL} for i, match := range matches { switch names[i] { case "dbname": - uri.DbName = match + uri.DBName = match case "params": if len(match) > 0 { kvs := strings.Split(match, "&") diff --git a/dialect_oracle.go b/dialects/oracle.go similarity
index 79% rename from dialect_oracle.go rename to dialects/oracle.go index fb57324..bb3c69c 100644 --- a/dialect_oracle.go +++ b/dialects/oracle.go @@ -2,19 +2,18 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package dialects import ( - "crypto/md5" - "encoding/base64" + "context" "errors" "fmt" - "io" "regexp" "strconv" "strings" - "github.com/xormplus/core" + "github.com/xormplus/xorm/core" + "github.com/xormplus/xorm/schemas" ) var ( @@ -499,32 +498,39 @@ var ( "YEAR": true, "ZONE": true, } + + oracleQuoter = schemas.Quoter{ + Prefix: '[', + Suffix: ']', + IsReserved: schemas.AlwaysReserve, + } ) type oracle struct { - core.Base + Base } -func (db *oracle) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *oracle) Init(uri *URI) error { + db.quoter = oracleQuoter + return db.Base.Init(db, uri) } -func (db *oracle) SqlType(c *core.Column) string { +func (db *oracle) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { - case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt, core.Bool, core.Serial, core.BigSerial: + case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt, schemas.Bool, schemas.Serial, schemas.BigSerial: res = "NUMBER" - case core.Binary, core.VarBinary, core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob, core.Bytea: - return core.Blob - case core.Time, core.DateTime, core.TimeStamp: - res = core.TimeStamp - case core.TimeStampz: + case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea: + return schemas.Blob + case schemas.Time, schemas.DateTime, schemas.TimeStamp: + res = schemas.TimeStamp + case schemas.TimeStampz: res = "TIMESTAMP WITH TIME ZONE" - case core.Float, core.Double, core.Numeric, core.Decimal: + case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal: res = "NUMBER" - case core.Text, core.MediumText, core.LongText, core.Json: + case schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json: res = "CLOB" - case core.Char, core.Varchar, core.TinyText: + case schemas.Char, schemas.Varchar, schemas.TinyText: res = "VARCHAR2" default: res = t @@ -545,46 +551,23 @@ func (db *oracle) AutoIncrStr() string { return "AUTO_INCREMENT" } -func (db *oracle) SupportInsertMany() bool { - return true -} - func (db *oracle) IsReserved(name string) bool { - _, ok := oracleReservedWords[name] + _, ok := oracleReservedWords[strings.ToUpper(name)] return ok } -func (db *oracle) Quote(name string) string { - return "[" + name + "]" +func (db *oracle) DropTableSQL(tableName string) (string, bool) { + return fmt.Sprintf("DROP TABLE `%s`", tableName), false } -func (db *oracle) SupportEngine() bool { - return false -} - -func (db *oracle) SupportCharset() bool { - return false -} - -func (db *oracle) SupportDropIfExists() bool { - return false -} - -func (db *oracle) IndexOnTable() bool { - return false -} - -func (db *oracle) DropTableSql(tableName string) string { - return fmt.Sprintf("DROP TABLE `%s`", tableName) -} - -func (db *oracle) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string { +func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { var sql = "CREATE TABLE " if tableName == "" { tableName = table.Name } - sql 
+= db.Quote(tableName) + " (" + quoter := db.Quoter() + sql += quoter.Quote(tableName) + " (" pkList := table.PrimaryKeys @@ -593,107 +576,72 @@ func (db *oracle) CreateTableSql(table *core.Table, tableName, storeEngine, char /*if col.IsPrimaryKey && len(pkList) == 1 { sql += col.String(b.dialect) } else {*/ - sql += col.StringNoPk(db) + s, _ := ColumnString(db, col, false) + sql += s // } sql = strings.TrimSpace(sql) sql += ", " } - quotes := db.Quote("") - if len(pkList) > 0 { sql += "PRIMARY KEY ( " - sql += db.Quote(strings.Join(pkList, fmt.Sprintf("%c,%c", quotes[1], quotes[0]))) + sql += quoter.Join(pkList, ",") sql += " ), " } sql = sql[:len(sql)-2] + ")" - if db.SupportEngine() && storeEngine != "" { - sql += " ENGINE=" + storeEngine - } - if db.SupportCharset() { - if len(charset) == 0 { - charset = db.URI().Charset - } - if len(charset) > 0 { - sql += " DEFAULT CHARSET " + charset - } + return []string{sql}, false +} + +func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = oracleQuoter + q.IsReserved = schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + var q = oracleQuoter + q.IsReserved = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = oracleQuoter } - return sql } -func (db *oracle) IndexCheckSql(tableName, idxName string) (string, []interface{}) { +func (db *oracle) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { args := []interface{}{tableName, idxName} return `SELECT INDEX_NAME FROM USER_INDEXES ` + `WHERE TABLE_NAME = :1 AND INDEX_NAME = :2`, args } -func (db *oracle) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{tableName} - return `SELECT table_name FROM user_tables WHERE table_name = :1`, args +func (db *oracle) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { + return db.HasRecords(queryer, ctx, `SELECT table_name FROM user_tables WHERE table_name = :1`, tableName) } -func (db *oracle) MustDropTable(tableName string) error { - sql, args := db.TableCheckSql(tableName) - db.LogSQL(sql, args) - - rows, err := db.DB().Query(sql, args...) - if err != nil { - return err - } - defer rows.Close() - - if !rows.Next() { - return nil - } - - sql = "Drop Table \"" + tableName + "\"" - db.LogSQL(sql, args) - - _, err = db.DB().Exec(sql) - return err -} - -/*func (db *oracle) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(colName)} - return "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" + - " AND column_name = ?", args -}*/ - -func (db *oracle) IsColumnExist(tableName, colName string) (bool, error) { +func (db *oracle) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { args := []interface{}{tableName, colName} query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = :1" + " AND column_name = :2" - db.LogSQL(query, args) - - rows, err := db.DB().Query(query, args...) - if err != nil { - return false, err - } - defer rows.Close() - - if rows.Next() { - return true, nil - } - return false, nil + return db.HasRecords(queryer, ctx, query, args...) 
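All three dialects implement SetQuotePolicy with the same switch, so its effect can be stated once: QuotePolicyAlways keeps the stock quoter (quote everything), QuotePolicyNone swaps the predicate for schemas.AlwaysNoReserve (quote nothing), and QuotePolicyReserved quotes only identifiers the dialect reports as reserved. A runnable sketch, assuming Quoter.Quote consults the IsReserved predicate (the stand-in predicate below is illustrative, not the dialect's real reserved-word table):

package main

import (
	"fmt"

	"github.com/xormplus/xorm/schemas"
)

func main() {
	// Shape of oracleQuoter from this file: always quote.
	always := schemas.Quoter{Prefix: '[', Suffix: ']', IsReserved: schemas.AlwaysReserve}
	// What QuotePolicyReserved installs: quote only reserved names.
	reservedOnly := always
	reservedOnly.IsReserved = func(name string) bool { return name == "SELECT" } // stand-in predicate
	fmt.Println(always.Quote("name"))         // expected: [name]
	fmt.Println(reservedOnly.Quote("name"))   // expected: name
	fmt.Println(reservedOnly.Quote("SELECT")) // expected: [SELECT]
}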
} -func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { +func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } defer rows.Close() - cols := make(map[string]*core.Column) + cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - col := new(core.Column) + col := new(schemas.Column) col.Indexes = make(map[string]int) var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string @@ -735,30 +683,30 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum switch dt { case "VARCHAR2": - col.SQLType = core.SQLType{Name: core.Varchar, DefaultLength: len1, DefaultLength2: len2} + col.SQLType = schemas.SQLType{Name: schemas.Varchar, DefaultLength: len1, DefaultLength2: len2} case "NVARCHAR2": - col.SQLType = core.SQLType{Name: core.NVarchar, DefaultLength: len1, DefaultLength2: len2} + col.SQLType = schemas.SQLType{Name: schemas.NVarchar, DefaultLength: len1, DefaultLength2: len2} case "TIMESTAMP WITH TIME ZONE": - col.SQLType = core.SQLType{Name: core.TimeStampz, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0} case "NUMBER": - col.SQLType = core.SQLType{Name: core.Double, DefaultLength: len1, DefaultLength2: len2} + col.SQLType = schemas.SQLType{Name: schemas.Double, DefaultLength: len1, DefaultLength2: len2} case "LONG", "LONG RAW": - col.SQLType = core.SQLType{Name: core.Text, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.Text, DefaultLength: 0, DefaultLength2: 0} case "RAW": - col.SQLType = core.SQLType{Name: core.Binary, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.Binary, DefaultLength: 0, DefaultLength2: 0} case "ROWID": - col.SQLType = core.SQLType{Name: core.Varchar, DefaultLength: 18, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.Varchar, DefaultLength: 18, DefaultLength2: 0} case "AQ$_SUBSCRIBERS": ignore = true default: - col.SQLType = core.SQLType{Name: strings.ToUpper(dt), DefaultLength: len1, DefaultLength2: len2} + col.SQLType = schemas.SQLType{Name: strings.ToUpper(dt), DefaultLength: len1, DefaultLength2: len2} } if ignore { continue } - if _, ok := core.SqlTypes[col.SQLType.Name]; !ok { + if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok { return nil, nil, fmt.Errorf("Unknown colType %v %v", *dataType, col.SQLType) } @@ -776,20 +724,19 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum return colSeq, cols, nil } -func (db *oracle) GetTables() ([]*core.Table, error) { +func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT table_name FROM user_tables" - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) 
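Filters() for oracle now returns only a SeqFilter{Prefix: ":", Start: 1} (QuoteFilter and IdFilter are gone): the filter rewrites xorm's generic ? placeholders into Oracle's positional :1, :2, ... form before execution. A stand-alone re-implementation of the idea, for illustration only; the real dialects.SeqFilter also has to leave ? characters inside string literals alone:

package main

import (
	"fmt"
	"strings"
)

// seqPlaceholders numbers each ? placeholder, mimicking what a
// SeqFilter{Prefix: ":", Start: 1} is expected to produce.
func seqPlaceholders(sql, prefix string, start int) string {
	var b strings.Builder
	n := start
	for _, r := range sql {
		if r == '?' {
			fmt.Fprintf(&b, "%s%d", prefix, n)
			n++
		} else {
			b.WriteRune(r)
		}
	}
	return b.String()
}

func main() {
	fmt.Println(seqPlaceholders("SELECT index_name FROM user_indexes WHERE table_name = ? AND index_name = ?", ":", 1))
	// SELECT index_name FROM user_indexes WHERE table_name = :1 AND index_name = :2
}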
if err != nil { return nil, err } defer rows.Close() - tables := make([]*core.Table, 0) + tables := make([]*schemas.Table, 0) for rows.Next() { - table := core.NewEmptyTable() + table := schemas.NewEmptyTable() err = rows.Scan(&table.Name) if err != nil { return nil, err @@ -800,19 +747,18 @@ func (db *oracle) GetTables() ([]*core.Table, error) { return tables, nil } -func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) { +func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " + "WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1" - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } defer rows.Close() - indexes := make(map[string]*core.Index, 0) + indexes := make(map[string]*schemas.Index, 0) for rows.Next() { var indexType int var indexName, colName, uniqueness string @@ -831,15 +777,15 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) { } if uniqueness == "UNIQUE" { - indexType = core.UniqueType + indexType = schemas.UniqueType } else { - indexType = core.IndexType + indexType = schemas.IndexType } - var index *core.Index + var index *schemas.Index var ok bool if index, ok = indexes[indexName]; !ok { - index = new(core.Index) + index = new(schemas.Index) index.Type = indexType index.Name = indexName index.IsRegular = isRegular @@ -850,78 +796,17 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) { return indexes, nil } -func oracle_hash(str string) string { - - if len(str) > 26 { - h := md5.New() - io.WriteString(h, str) - return base64.StdEncoding.EncodeToString(h.Sum(nil)) +func (db *oracle) Filters() []Filter { + return []Filter{ + &SeqFilter{Prefix: ":", Start: 1}, } - - return str -} - -func oracle_index_name(index *core.Index, tableName string) string { - - if !strings.HasPrefix(index.Name, "UQE_") && - !strings.HasPrefix(index.Name, "IDX_") { - - name := oracle_hash(fmt.Sprintf("%v_%v", tableName, index.Name)) - - if index.Type == core.UniqueType { - return fmt.Sprintf("UQE_%v", name) - } - - return fmt.Sprintf("IDX_%v", name) - } - - return index.Name -} - -func (db *oracle) CreateIndexSql(tableName string, index *core.Index) string { - - quote := db.Quote - var unique string - var idxName string - - if index.Type == core.UniqueType { - unique = " UNIQUE" - } - - idxName = oracle_index_name(index, tableName) - - return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique, - quote(idxName), quote(tableName), - quote(strings.Join(index.Cols, quote(",")))) -} - -func (db *oracle) DropIndexSql(tableName string, index *core.Index) string { - - if strings.HasPrefix(index.Name, "SYS_") { - return "" - } - - quote := db.Quote - var name string - - if index.IsRegular { - name = oracle_index_name(index, tableName) - } else { - name = index.Name - } - - return fmt.Sprintf("DROP INDEX %v", quote(name)) -} - -func (db *oracle) Filters() []core.Filter { - return []core.Filter{&core.QuoteFilter{}, &core.SeqFilter{Prefix: ":", Start: 1}, &core.IdFilter{}} } type goracleDriver struct { } -func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { - db := &core.Uri{DbType: core.ORACLE} +func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*URI, 
error) { + db := &URI{DBType: schemas.ORACLE} dsnPattern := regexp.MustCompile( `^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@] `(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]] @@ -934,10 +819,10 @@ func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*core.Uri, e for i, match := range matches { switch names[i] { case "dbname": - db.DbName = match + db.DBName = match } } - if db.DbName == "" { + if db.DBName == "" { return nil, errors.New("dbname is empty") } return db, nil @@ -948,8 +833,8 @@ type oci8Driver struct { // dataSourceName=user/password@ipv4:port/dbname // dataSourceName=user/password@[ipv6]:port/dbname -func (p *oci8Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) { - db := &core.Uri{DbType: core.ORACLE} +func (p *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) { + db := &URI{DBType: schemas.ORACLE} dsnPattern := regexp.MustCompile( `^(?P<user>.*)\/(?P<password>.*)@` + // user:password@ `(?P<netstr>.*)` + // ip:port @@ -959,10 +844,10 @@ func (p *oci8Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) for i, match := range matches { switch names[i] { case "dbname": - db.DbName = match + db.DBName = match } } - if db.DbName == "" { + if db.DBName == "" && len(matches) != 0 { return nil, errors.New("dbname is empty") } return db, nil diff --git a/dialects/oracle_test.go b/dialects/oracle_test.go new file mode 100644 index 0000000..9c3a93f --- /dev/null +++ b/dialects/oracle_test.go @@ -0,0 +1,34 @@ +package dialects + +import ( + "reflect" + "testing" +) + +func TestParseOracleConnStr(t *testing.T) { + tests := []struct { + in string + expected string + valid bool + }{ + {"user/pass@tcp(server:1521)/db", "db", true}, + {"user/pass@server:1521/db", "db", true}, + // test for net service name : https://docs.oracle.com/cd/B13789_01/network.101/b10775/glossary.htm#i998113 + {"user/pass@server:1521", "", true}, + {"user/pass@", "", false}, + {"user/pass", "", false}, + {"", "", false}, + } + driver := QueryDriver("oci8") + for _, test := range tests { + t.Run(test.in, func(t *testing.T) { + driver := driver + uri, err := driver.Parse("oci8", test.in) + if err != nil && test.valid { + t.Errorf("%q got unexpected error: %s", test.in, err) + } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { + t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) + } + }) + } +} diff --git a/dialects/pg_reserved.txt b/dialects/pg_reserved.txt new file mode 100644 index 0000000..720ed37 --- /dev/null +++ b/dialects/pg_reserved.txt @@ -0,0 +1,746 @@ +A non-reserved non-reserved +ABORT non-reserved +ABS reserved reserved +ABSENT non-reserved non-reserved +ABSOLUTE non-reserved non-reserved non-reserved reserved +ACCESS non-reserved +ACCORDING non-reserved non-reserved +ACTION non-reserved non-reserved non-reserved reserved +ADA non-reserved non-reserved non-reserved +ADD non-reserved non-reserved non-reserved reserved +ADMIN non-reserved non-reserved non-reserved +AFTER non-reserved non-reserved non-reserved +AGGREGATE non-reserved +ALL reserved reserved reserved reserved +ALLOCATE reserved reserved reserved +ALSO non-reserved +ALTER non-reserved reserved reserved reserved +ALWAYS non-reserved non-reserved non-reserved +ANALYSE reserved +ANALYZE reserved +AND reserved reserved reserved reserved +ANY reserved reserved reserved reserved +ARE reserved reserved reserved +ARRAY reserved reserved reserved +ARRAY_AGG reserved reserved +ARRAY_MAX_CARDINALITY reserved +AS reserved reserved reserved reserved +ASC reserved
non-reserved non-reserved reserved +ASENSITIVE reserved reserved +ASSERTION non-reserved non-reserved non-reserved reserved +ASSIGNMENT non-reserved non-reserved non-reserved +ASYMMETRIC reserved reserved reserved +AT non-reserved reserved reserved reserved +ATOMIC reserved reserved +ATTRIBUTE non-reserved non-reserved non-reserved +ATTRIBUTES non-reserved non-reserved +AUTHORIZATION reserved (can be function or type) reserved reserved reserved +AVG reserved reserved reserved +BACKWARD non-reserved +BASE64 non-reserved non-reserved +BEFORE non-reserved non-reserved non-reserved +BEGIN non-reserved reserved reserved reserved +BEGIN_FRAME reserved +BEGIN_PARTITION reserved +BERNOULLI non-reserved non-reserved +BETWEEN non-reserved (cannot be function or type) reserved reserved reserved +BIGINT non-reserved (cannot be function or type) reserved reserved +BINARY reserved (can be function or type) reserved reserved +BIT non-reserved (cannot be function or type) reserved +BIT_LENGTH reserved +BLOB reserved reserved +BLOCKED non-reserved non-reserved +BOM non-reserved non-reserved +BOOLEAN non-reserved (cannot be function or type) reserved reserved +BOTH reserved reserved reserved reserved +BREADTH non-reserved non-reserved +BY non-reserved reserved reserved reserved +C non-reserved non-reserved non-reserved +CACHE non-reserved +CALL reserved reserved +CALLED non-reserved reserved reserved +CARDINALITY reserved reserved +CASCADE non-reserved non-reserved non-reserved reserved +CASCADED non-reserved reserved reserved reserved +CASE reserved reserved reserved reserved +CAST reserved reserved reserved reserved +CATALOG non-reserved non-reserved non-reserved reserved +CATALOG_NAME non-reserved non-reserved non-reserved +CEIL reserved reserved +CEILING reserved reserved +CHAIN non-reserved non-reserved non-reserved +CHAR non-reserved (cannot be function or type) reserved reserved reserved +CHARACTER non-reserved (cannot be function or type) reserved reserved reserved +CHARACTERISTICS non-reserved non-reserved non-reserved +CHARACTERS non-reserved non-reserved +CHARACTER_LENGTH reserved reserved reserved +CHARACTER_SET_CATALOG non-reserved non-reserved non-reserved +CHARACTER_SET_NAME non-reserved non-reserved non-reserved +CHARACTER_SET_SCHEMA non-reserved non-reserved non-reserved +CHAR_LENGTH reserved reserved reserved +CHECK reserved reserved reserved reserved +CHECKPOINT non-reserved +CLASS non-reserved +CLASS_ORIGIN non-reserved non-reserved non-reserved +CLOB reserved reserved +CLOSE non-reserved reserved reserved reserved +CLUSTER non-reserved +COALESCE non-reserved (cannot be function or type) reserved reserved reserved +COBOL non-reserved non-reserved non-reserved +COLLATE reserved reserved reserved reserved +COLLATION reserved (can be function or type) non-reserved non-reserved reserved +COLLATION_CATALOG non-reserved non-reserved non-reserved +COLLATION_NAME non-reserved non-reserved non-reserved +COLLATION_SCHEMA non-reserved non-reserved non-reserved +COLLECT reserved reserved +COLUMN reserved reserved reserved reserved +COLUMNS non-reserved non-reserved +COLUMN_NAME non-reserved non-reserved non-reserved +COMMAND_FUNCTION non-reserved non-reserved non-reserved +COMMAND_FUNCTION_CODE non-reserved non-reserved +COMMENT non-reserved +COMMENTS non-reserved +COMMIT non-reserved reserved reserved reserved +COMMITTED non-reserved non-reserved non-reserved non-reserved +CONCURRENTLY reserved (can be function or type) +CONDITION reserved reserved +CONDITION_NUMBER non-reserved non-reserved 
non-reserved +CONFIGURATION non-reserved +CONNECT reserved reserved reserved +CONNECTION non-reserved non-reserved non-reserved reserved +CONNECTION_NAME non-reserved non-reserved non-reserved +CONSTRAINT reserved reserved reserved reserved +CONSTRAINTS non-reserved non-reserved non-reserved reserved +CONSTRAINT_CATALOG non-reserved non-reserved non-reserved +CONSTRAINT_NAME non-reserved non-reserved non-reserved +CONSTRAINT_SCHEMA non-reserved non-reserved non-reserved +CONSTRUCTOR non-reserved non-reserved +CONTAINS reserved non-reserved +CONTENT non-reserved non-reserved non-reserved +CONTINUE non-reserved non-reserved non-reserved reserved +CONTROL non-reserved non-reserved +CONVERSION non-reserved +CONVERT reserved reserved reserved +COPY non-reserved +CORR reserved reserved +CORRESPONDING reserved reserved reserved +COST non-reserved +COUNT reserved reserved reserved +COVAR_POP reserved reserved +COVAR_SAMP reserved reserved +CREATE reserved reserved reserved reserved +CROSS reserved (can be function or type) reserved reserved reserved +CSV non-reserved +CUBE reserved reserved +CUME_DIST reserved reserved +CURRENT non-reserved reserved reserved reserved +CURRENT_CATALOG reserved reserved reserved +CURRENT_DATE reserved reserved reserved reserved +CURRENT_DEFAULT_TRANSFORM_GROUP reserved reserved +CURRENT_PATH reserved reserved +CURRENT_ROLE reserved reserved reserved +CURRENT_ROW reserved +CURRENT_SCHEMA reserved (can be function or type) reserved reserved +CURRENT_TIME reserved reserved reserved reserved +CURRENT_TIMESTAMP reserved reserved reserved reserved +CURRENT_TRANSFORM_GROUP_FOR_TYPE reserved reserved +CURRENT_USER reserved reserved reserved reserved +CURSOR non-reserved reserved reserved reserved +CURSOR_NAME non-reserved non-reserved non-reserved +CYCLE non-reserved reserved reserved +DATA non-reserved non-reserved non-reserved non-reserved +DATABASE non-reserved +DATALINK reserved reserved +DATE reserved reserved reserved +DATETIME_INTERVAL_CODE non-reserved non-reserved non-reserved +DATETIME_INTERVAL_PRECISION non-reserved non-reserved non-reserved +DAY non-reserved reserved reserved reserved +DB non-reserved non-reserved +DEALLOCATE non-reserved reserved reserved reserved +DEC non-reserved (cannot be function or type) reserved reserved reserved +DECIMAL non-reserved (cannot be function or type) reserved reserved reserved +DECLARE non-reserved reserved reserved reserved +DEFAULT reserved reserved reserved reserved +DEFAULTS non-reserved non-reserved non-reserved +DEFERRABLE reserved non-reserved non-reserved reserved +DEFERRED non-reserved non-reserved non-reserved reserved +DEFINED non-reserved non-reserved +DEFINER non-reserved non-reserved non-reserved +DEGREE non-reserved non-reserved +DELETE non-reserved reserved reserved reserved +DELIMITER non-reserved +DELIMITERS non-reserved +DENSE_RANK reserved reserved +DEPTH non-reserved non-reserved +DEREF reserved reserved +DERIVED non-reserved non-reserved +DESC reserved non-reserved non-reserved reserved +DESCRIBE reserved reserved reserved +DESCRIPTOR non-reserved non-reserved reserved +DETERMINISTIC reserved reserved +DIAGNOSTICS non-reserved non-reserved reserved +DICTIONARY non-reserved +DISABLE non-reserved +DISCARD non-reserved +DISCONNECT reserved reserved reserved +DISPATCH non-reserved non-reserved +DISTINCT reserved reserved reserved reserved +DLNEWCOPY reserved reserved +DLPREVIOUSCOPY reserved reserved +DLURLCOMPLETE reserved reserved +DLURLCOMPLETEONLY reserved reserved +DLURLCOMPLETEWRITE reserved reserved 
+DLURLPATH reserved reserved +DLURLPATHONLY reserved reserved +DLURLPATHWRITE reserved reserved +DLURLSCHEME reserved reserved +DLURLSERVER reserved reserved +DLVALUE reserved reserved +DO reserved +DOCUMENT non-reserved non-reserved non-reserved +DOMAIN non-reserved non-reserved non-reserved reserved +DOUBLE non-reserved reserved reserved reserved +DROP non-reserved reserved reserved reserved +DYNAMIC reserved reserved +DYNAMIC_FUNCTION non-reserved non-reserved non-reserved +DYNAMIC_FUNCTION_CODE non-reserved non-reserved +EACH non-reserved reserved reserved +ELEMENT reserved reserved +ELSE reserved reserved reserved reserved +EMPTY non-reserved non-reserved +ENABLE non-reserved +ENCODING non-reserved non-reserved non-reserved +ENCRYPTED non-reserved +END reserved reserved reserved reserved +END-EXEC reserved reserved reserved +END_FRAME reserved +END_PARTITION reserved +ENFORCED non-reserved +ENUM non-reserved +EQUALS reserved non-reserved +ESCAPE non-reserved reserved reserved reserved +EVENT non-reserved +EVERY reserved reserved +EXCEPT reserved reserved reserved reserved +EXCEPTION reserved +EXCLUDE non-reserved non-reserved non-reserved +EXCLUDING non-reserved non-reserved non-reserved +EXCLUSIVE non-reserved +EXEC reserved reserved reserved +EXECUTE non-reserved reserved reserved reserved +EXISTS non-reserved (cannot be function or type) reserved reserved reserved +EXP reserved reserved +EXPLAIN non-reserved +EXPRESSION non-reserved +EXTENSION non-reserved +EXTERNAL non-reserved reserved reserved reserved +EXTRACT non-reserved (cannot be function or type) reserved reserved reserved +FALSE reserved reserved reserved reserved +FAMILY non-reserved +FETCH reserved reserved reserved reserved +FILE non-reserved non-reserved +FILTER reserved reserved +FINAL non-reserved non-reserved +FIRST non-reserved non-reserved non-reserved reserved +FIRST_VALUE reserved reserved +FLAG non-reserved non-reserved +FLOAT non-reserved (cannot be function or type) reserved reserved reserved +FLOOR reserved reserved +FOLLOWING non-reserved non-reserved non-reserved +FOR reserved reserved reserved reserved +FORCE non-reserved +FOREIGN reserved reserved reserved reserved +FORTRAN non-reserved non-reserved non-reserved +FORWARD non-reserved +FOUND non-reserved non-reserved reserved +FRAME_ROW reserved +FREE reserved reserved +FREEZE reserved (can be function or type) +FROM reserved reserved reserved reserved +FS non-reserved non-reserved +FULL reserved (can be function or type) reserved reserved reserved +FUNCTION non-reserved reserved reserved +FUNCTIONS non-reserved +FUSION reserved reserved +G non-reserved non-reserved +GENERAL non-reserved non-reserved +GENERATED non-reserved non-reserved +GET reserved reserved reserved +GLOBAL non-reserved reserved reserved reserved +GO non-reserved non-reserved reserved +GOTO non-reserved non-reserved reserved +GRANT reserved reserved reserved reserved +GRANTED non-reserved non-reserved non-reserved +GREATEST non-reserved (cannot be function or type) +GROUP reserved reserved reserved reserved +GROUPING reserved reserved +GROUPS reserved +HANDLER non-reserved +HAVING reserved reserved reserved reserved +HEADER non-reserved +HEX non-reserved non-reserved +HIERARCHY non-reserved non-reserved +HOLD non-reserved reserved reserved +HOUR non-reserved reserved reserved reserved +ID non-reserved non-reserved +IDENTITY non-reserved reserved reserved reserved +IF non-reserved +IGNORE non-reserved non-reserved +ILIKE reserved (can be function or type) +IMMEDIATE non-reserved 
non-reserved non-reserved reserved +IMMEDIATELY non-reserved +IMMUTABLE non-reserved +IMPLEMENTATION non-reserved non-reserved +IMPLICIT non-reserved +IMPORT reserved reserved +IN reserved reserved reserved reserved +INCLUDING non-reserved non-reserved non-reserved +INCREMENT non-reserved non-reserved non-reserved +INDENT non-reserved non-reserved +INDEX non-reserved +INDEXES non-reserved +INDICATOR reserved reserved reserved +INHERIT non-reserved +INHERITS non-reserved +INITIALLY reserved non-reserved non-reserved reserved +INLINE non-reserved +INNER reserved (can be function or type) reserved reserved reserved +INOUT non-reserved (cannot be function or type) reserved reserved +INPUT non-reserved non-reserved non-reserved reserved +INSENSITIVE non-reserved reserved reserved reserved +INSERT non-reserved reserved reserved reserved +INSTANCE non-reserved non-reserved +INSTANTIABLE non-reserved non-reserved +INSTEAD non-reserved non-reserved non-reserved +INT non-reserved (cannot be function or type) reserved reserved reserved +INTEGER non-reserved (cannot be function or type) reserved reserved reserved +INTEGRITY non-reserved non-reserved +INTERSECT reserved reserved reserved reserved +INTERSECTION reserved reserved +INTERVAL non-reserved (cannot be function or type) reserved reserved reserved +INTO reserved reserved reserved reserved +INVOKER non-reserved non-reserved non-reserved +IS reserved (can be function or type) reserved reserved reserved +ISNULL reserved (can be function or type) +ISOLATION non-reserved non-reserved non-reserved reserved +JOIN reserved (can be function or type) reserved reserved reserved +K non-reserved non-reserved +KEY non-reserved non-reserved non-reserved reserved +KEY_MEMBER non-reserved non-reserved +KEY_TYPE non-reserved non-reserved +LABEL non-reserved +LAG reserved reserved +LANGUAGE non-reserved reserved reserved reserved +LARGE non-reserved reserved reserved +LAST non-reserved non-reserved non-reserved reserved +LAST_VALUE reserved reserved +LATERAL reserved reserved reserved +LC_COLLATE non-reserved +LC_CTYPE non-reserved +LEAD reserved reserved +LEADING reserved reserved reserved reserved +LEAKPROOF non-reserved +LEAST non-reserved (cannot be function or type) +LEFT reserved (can be function or type) reserved reserved reserved +LENGTH non-reserved non-reserved non-reserved +LEVEL non-reserved non-reserved non-reserved reserved +LIBRARY non-reserved non-reserved +LIKE reserved (can be function or type) reserved reserved reserved +LIKE_REGEX reserved reserved +LIMIT reserved non-reserved non-reserved +LINK non-reserved non-reserved +LISTEN non-reserved +LN reserved reserved +LOAD non-reserved +LOCAL non-reserved reserved reserved reserved +LOCALTIME reserved reserved reserved +LOCALTIMESTAMP reserved reserved reserved +LOCATION non-reserved non-reserved non-reserved +LOCATOR non-reserved non-reserved +LOCK non-reserved +LOWER reserved reserved reserved +M non-reserved non-reserved +MAP non-reserved non-reserved +MAPPING non-reserved non-reserved non-reserved +MATCH non-reserved reserved reserved reserved +MATCHED non-reserved non-reserved +MATERIALIZED non-reserved +MAX reserved reserved reserved +MAXVALUE non-reserved non-reserved non-reserved +MAX_CARDINALITY reserved +MEMBER reserved reserved +MERGE reserved reserved +MESSAGE_LENGTH non-reserved non-reserved non-reserved +MESSAGE_OCTET_LENGTH non-reserved non-reserved non-reserved +MESSAGE_TEXT non-reserved non-reserved non-reserved +METHOD reserved reserved +MIN reserved reserved reserved +MINUTE 
non-reserved reserved reserved reserved +MINVALUE non-reserved non-reserved non-reserved +MOD reserved reserved +MODE non-reserved +MODIFIES reserved reserved +MODULE reserved reserved reserved +MONTH non-reserved reserved reserved reserved +MORE non-reserved non-reserved non-reserved +MOVE non-reserved +MULTISET reserved reserved +MUMPS non-reserved non-reserved non-reserved +NAME non-reserved non-reserved non-reserved non-reserved +NAMES non-reserved non-reserved non-reserved reserved +NAMESPACE non-reserved non-reserved +NATIONAL non-reserved (cannot be function or type) reserved reserved reserved +NATURAL reserved (can be function or type) reserved reserved reserved +NCHAR non-reserved (cannot be function or type) reserved reserved reserved +NCLOB reserved reserved +NESTING non-reserved non-reserved +NEW reserved reserved +NEXT non-reserved non-reserved non-reserved reserved +NFC non-reserved non-reserved +NFD non-reserved non-reserved +NFKC non-reserved non-reserved +NFKD non-reserved non-reserved +NIL non-reserved non-reserved +NO non-reserved reserved reserved reserved +NONE non-reserved (cannot be function or type) reserved reserved +NORMALIZE reserved reserved +NORMALIZED non-reserved non-reserved +NOT reserved reserved reserved reserved +NOTHING non-reserved +NOTIFY non-reserved +NOTNULL reserved (can be function or type) +NOWAIT non-reserved +NTH_VALUE reserved reserved +NTILE reserved reserved +NULL reserved reserved reserved reserved +NULLABLE non-reserved non-reserved non-reserved +NULLIF non-reserved (cannot be function or type) reserved reserved reserved +NULLS non-reserved non-reserved non-reserved +NUMBER non-reserved non-reserved non-reserved +NUMERIC non-reserved (cannot be function or type) reserved reserved reserved +OBJECT non-reserved non-reserved non-reserved +OCCURRENCES_REGEX reserved reserved +OCTETS non-reserved non-reserved +OCTET_LENGTH reserved reserved reserved +OF non-reserved reserved reserved reserved +OFF non-reserved non-reserved non-reserved +OFFSET reserved reserved reserved +OIDS non-reserved +OLD reserved reserved +ON reserved reserved reserved reserved +ONLY reserved reserved reserved reserved +OPEN reserved reserved reserved +OPERATOR non-reserved +OPTION non-reserved non-reserved non-reserved reserved +OPTIONS non-reserved non-reserved non-reserved +OR reserved reserved reserved reserved +ORDER reserved reserved reserved reserved +ORDERING non-reserved non-reserved +ORDINALITY non-reserved non-reserved +OTHERS non-reserved non-reserved +OUT non-reserved (cannot be function or type) reserved reserved +OUTER reserved (can be function or type) reserved reserved reserved +OUTPUT non-reserved non-reserved reserved +OVER reserved (can be function or type) reserved reserved +OVERLAPS reserved (can be function or type) reserved reserved reserved +OVERLAY non-reserved (cannot be function or type) reserved reserved +OVERRIDING non-reserved non-reserved +OWNED non-reserved +OWNER non-reserved +P non-reserved non-reserved +PAD non-reserved non-reserved reserved +PARAMETER reserved reserved +PARAMETER_MODE non-reserved non-reserved +PARAMETER_NAME non-reserved non-reserved +PARAMETER_ORDINAL_POSITION non-reserved non-reserved +PARAMETER_SPECIFIC_CATALOG non-reserved non-reserved +PARAMETER_SPECIFIC_NAME non-reserved non-reserved +PARAMETER_SPECIFIC_SCHEMA non-reserved non-reserved +PARSER non-reserved +PARTIAL non-reserved non-reserved non-reserved reserved +PARTITION non-reserved reserved reserved +PASCAL non-reserved non-reserved non-reserved +PASSING 
non-reserved non-reserved non-reserved +PASSTHROUGH non-reserved non-reserved +PASSWORD non-reserved +PATH non-reserved non-reserved +PERCENT reserved +PERCENTILE_CONT reserved reserved +PERCENTILE_DISC reserved reserved +PERCENT_RANK reserved reserved +PERIOD reserved +PERMISSION non-reserved non-reserved +PLACING reserved non-reserved non-reserved +PLANS non-reserved +PLI non-reserved non-reserved non-reserved +PORTION reserved +POSITION non-reserved (cannot be function or type) reserved reserved reserved +POSITION_REGEX reserved reserved +POWER reserved reserved +PRECEDES reserved +PRECEDING non-reserved non-reserved non-reserved +PRECISION non-reserved (cannot be function or type) reserved reserved reserved +PREPARE non-reserved reserved reserved reserved +PREPARED non-reserved +PRESERVE non-reserved non-reserved non-reserved reserved +PRIMARY reserved reserved reserved reserved +PRIOR non-reserved non-reserved non-reserved reserved +PRIVILEGES non-reserved non-reserved non-reserved reserved +PROCEDURAL non-reserved +PROCEDURE non-reserved reserved reserved reserved +PROGRAM non-reserved +PUBLIC non-reserved non-reserved reserved +QUOTE non-reserved +RANGE non-reserved reserved reserved +RANK reserved reserved +READ non-reserved non-reserved non-reserved reserved +READS reserved reserved +REAL non-reserved (cannot be function or type) reserved reserved reserved +REASSIGN non-reserved +RECHECK non-reserved +RECOVERY non-reserved non-reserved +RECURSIVE non-reserved reserved reserved +REF non-reserved reserved reserved +REFERENCES reserved reserved reserved reserved +REFERENCING reserved reserved +REFRESH non-reserved +REGR_AVGX reserved reserved +REGR_AVGY reserved reserved +REGR_COUNT reserved reserved +REGR_INTERCEPT reserved reserved +REGR_R2 reserved reserved +REGR_SLOPE reserved reserved +REGR_SXX reserved reserved +REGR_SXY reserved reserved +REGR_SYY reserved reserved +REINDEX non-reserved +RELATIVE non-reserved non-reserved non-reserved reserved +RELEASE non-reserved reserved reserved +RENAME non-reserved +REPEATABLE non-reserved non-reserved non-reserved non-reserved +REPLACE non-reserved +REPLICA non-reserved +REQUIRING non-reserved non-reserved +RESET non-reserved +RESPECT non-reserved non-reserved +RESTART non-reserved non-reserved non-reserved +RESTORE non-reserved non-reserved +RESTRICT non-reserved non-reserved non-reserved reserved +RESULT reserved reserved +RETURN reserved reserved +RETURNED_CARDINALITY non-reserved non-reserved +RETURNED_LENGTH non-reserved non-reserved non-reserved +RETURNED_OCTET_LENGTH non-reserved non-reserved non-reserved +RETURNED_SQLSTATE non-reserved non-reserved non-reserved +RETURNING reserved non-reserved non-reserved +RETURNS non-reserved reserved reserved +REVOKE non-reserved reserved reserved reserved +RIGHT reserved (can be function or type) reserved reserved reserved +ROLE non-reserved non-reserved non-reserved +ROLLBACK non-reserved reserved reserved reserved +ROLLUP reserved reserved +ROUTINE non-reserved non-reserved +ROUTINE_CATALOG non-reserved non-reserved +ROUTINE_NAME non-reserved non-reserved +ROUTINE_SCHEMA non-reserved non-reserved +ROW non-reserved (cannot be function or type) reserved reserved +ROWS non-reserved reserved reserved reserved +ROW_COUNT non-reserved non-reserved non-reserved +ROW_NUMBER reserved reserved +RULE non-reserved +SAVEPOINT non-reserved reserved reserved +SCALE non-reserved non-reserved non-reserved +SCHEMA non-reserved non-reserved non-reserved reserved +SCHEMA_NAME non-reserved non-reserved 
non-reserved +SCOPE reserved reserved +SCOPE_CATALOG non-reserved non-reserved +SCOPE_NAME non-reserved non-reserved +SCOPE_SCHEMA non-reserved non-reserved +SCROLL non-reserved reserved reserved reserved +SEARCH non-reserved reserved reserved +SECOND non-reserved reserved reserved reserved +SECTION non-reserved non-reserved reserved +SECURITY non-reserved non-reserved non-reserved +SELECT reserved reserved reserved reserved +SELECTIVE non-reserved non-reserved +SELF non-reserved non-reserved +SENSITIVE reserved reserved +SEQUENCE non-reserved non-reserved non-reserved +SEQUENCES non-reserved +SERIALIZABLE non-reserved non-reserved non-reserved non-reserved +SERVER non-reserved non-reserved non-reserved +SERVER_NAME non-reserved non-reserved non-reserved +SESSION non-reserved non-reserved non-reserved reserved +SESSION_USER reserved reserved reserved reserved +SET non-reserved reserved reserved reserved +SETOF non-reserved (cannot be function or type) +SETS non-reserved non-reserved +SHARE non-reserved +SHOW non-reserved +SIMILAR reserved (can be function or type) reserved reserved +SIMPLE non-reserved non-reserved non-reserved +SIZE non-reserved non-reserved reserved +SMALLINT non-reserved (cannot be function or type) reserved reserved reserved +SNAPSHOT non-reserved +SOME reserved reserved reserved reserved +SOURCE non-reserved non-reserved +SPACE non-reserved non-reserved reserved +SPECIFIC reserved reserved +SPECIFICTYPE reserved reserved +SPECIFIC_NAME non-reserved non-reserved +SQL reserved reserved reserved +SQLCODE reserved +SQLERROR reserved +SQLEXCEPTION reserved reserved +SQLSTATE reserved reserved reserved +SQLWARNING reserved reserved +SQRT reserved reserved +STABLE non-reserved +STANDALONE non-reserved non-reserved non-reserved +START non-reserved reserved reserved +STATE non-reserved non-reserved +STATEMENT non-reserved non-reserved non-reserved +STATIC reserved reserved +STATISTICS non-reserved +STDDEV_POP reserved reserved +STDDEV_SAMP reserved reserved +STDIN non-reserved +STDOUT non-reserved +STORAGE non-reserved +STRICT non-reserved +STRIP non-reserved non-reserved non-reserved +STRUCTURE non-reserved non-reserved +STYLE non-reserved non-reserved +SUBCLASS_ORIGIN non-reserved non-reserved non-reserved +SUBMULTISET reserved reserved +SUBSTRING non-reserved (cannot be function or type) reserved reserved reserved +SUBSTRING_REGEX reserved reserved +SUCCEEDS reserved +SUM reserved reserved reserved +SYMMETRIC reserved reserved reserved +SYSID non-reserved +SYSTEM non-reserved reserved reserved +SYSTEM_TIME reserved +SYSTEM_USER reserved reserved reserved +T non-reserved non-reserved +TABLE reserved reserved reserved reserved +TABLES non-reserved +TABLESAMPLE reserved reserved +TABLESPACE non-reserved +TABLE_NAME non-reserved non-reserved non-reserved +TEMP non-reserved +TEMPLATE non-reserved +TEMPORARY non-reserved non-reserved non-reserved reserved +TEXT non-reserved +THEN reserved reserved reserved reserved +TIES non-reserved non-reserved +TIME non-reserved (cannot be function or type) reserved reserved reserved +TIMESTAMP non-reserved (cannot be function or type) reserved reserved reserved +TIMEZONE_HOUR reserved reserved reserved +TIMEZONE_MINUTE reserved reserved reserved +TO reserved reserved reserved reserved +TOKEN non-reserved non-reserved +TOP_LEVEL_COUNT non-reserved non-reserved +TRAILING reserved reserved reserved reserved +TRANSACTION non-reserved non-reserved non-reserved reserved +TRANSACTIONS_COMMITTED non-reserved non-reserved +TRANSACTIONS_ROLLED_BACK 
non-reserved non-reserved +TRANSACTION_ACTIVE non-reserved non-reserved +TRANSFORM non-reserved non-reserved +TRANSFORMS non-reserved non-reserved +TRANSLATE reserved reserved reserved +TRANSLATE_REGEX reserved reserved +TRANSLATION reserved reserved reserved +TREAT non-reserved (cannot be function or type) reserved reserved +TRIGGER non-reserved reserved reserved +TRIGGER_CATALOG non-reserved non-reserved +TRIGGER_NAME non-reserved non-reserved +TRIGGER_SCHEMA non-reserved non-reserved +TRIM non-reserved (cannot be function or type) reserved reserved reserved +TRIM_ARRAY reserved reserved +TRUE reserved reserved reserved reserved +TRUNCATE non-reserved reserved reserved +TRUSTED non-reserved +TYPE non-reserved non-reserved non-reserved non-reserved +TYPES non-reserved +UESCAPE reserved reserved +UNBOUNDED non-reserved non-reserved non-reserved +UNCOMMITTED non-reserved non-reserved non-reserved non-reserved +UNDER non-reserved non-reserved +UNENCRYPTED non-reserved +UNION reserved reserved reserved reserved +UNIQUE reserved reserved reserved reserved +UNKNOWN non-reserved reserved reserved reserved +UNLINK non-reserved non-reserved +UNLISTEN non-reserved +UNLOGGED non-reserved +UNNAMED non-reserved non-reserved non-reserved +UNNEST reserved reserved +UNTIL non-reserved +UNTYPED non-reserved non-reserved +UPDATE non-reserved reserved reserved reserved +UPPER reserved reserved reserved +URI non-reserved non-reserved +USAGE non-reserved non-reserved reserved +USER reserved reserved reserved reserved +USER_DEFINED_TYPE_CATALOG non-reserved non-reserved +USER_DEFINED_TYPE_CODE non-reserved non-reserved +USER_DEFINED_TYPE_NAME non-reserved non-reserved +USER_DEFINED_TYPE_SCHEMA non-reserved non-reserved +USING reserved reserved reserved reserved +VACUUM non-reserved +VALID non-reserved non-reserved non-reserved +VALIDATE non-reserved +VALIDATOR non-reserved +VALUE non-reserved reserved reserved reserved +VALUES non-reserved (cannot be function or type) reserved reserved reserved +VALUE_OF reserved +VARBINARY reserved reserved +VARCHAR non-reserved (cannot be function or type) reserved reserved reserved +VARIADIC reserved +VARYING non-reserved reserved reserved reserved +VAR_POP reserved reserved +VAR_SAMP reserved reserved +VERBOSE reserved (can be function or type) +VERSION non-reserved non-reserved non-reserved +VERSIONING reserved +VIEW non-reserved non-reserved non-reserved reserved +VOLATILE non-reserved +WHEN reserved reserved reserved reserved +WHENEVER reserved reserved reserved +WHERE reserved reserved reserved reserved +WHITESPACE non-reserved non-reserved non-reserved +WIDTH_BUCKET reserved reserved +WINDOW reserved reserved reserved +WITH reserved reserved reserved reserved +WITHIN reserved reserved +WITHOUT non-reserved reserved reserved +WORK non-reserved non-reserved non-reserved reserved +WRAPPER non-reserved non-reserved non-reserved +WRITE non-reserved non-reserved non-reserved reserved +XML non-reserved reserved reserved +XMLAGG reserved reserved +XMLATTRIBUTES non-reserved (cannot be function or type) reserved reserved +XMLBINARY reserved reserved +XMLCAST reserved reserved +XMLCOMMENT reserved reserved +XMLCONCAT non-reserved (cannot be function or type) reserved reserved +XMLDECLARATION non-reserved non-reserved +XMLDOCUMENT reserved reserved +XMLELEMENT non-reserved (cannot be function or type) reserved reserved +XMLEXISTS non-reserved (cannot be function or type) reserved reserved +XMLFOREST non-reserved (cannot be function or type) reserved reserved +XMLITERATE reserved 
reserved +XMLNAMESPACES reserved reserved +XMLPARSE non-reserved (cannot be function or type) reserved reserved +XMLPI non-reserved (cannot be function or type) reserved reserved +XMLQUERY reserved reserved +XMLROOT non-reserved (cannot be function or type) +XMLSCHEMA non-reserved non-reserved +XMLSERIALIZE non-reserved (cannot be function or type) reserved reserved +XMLTABLE reserved reserved +XMLTEXT reserved reserved +XMLVALIDATE reserved reserved +YEAR non-reserved reserved reserved reserved +YES non-reserved non-reserved non-reserved +ZONE non-reserved non-reserved non-reserved reserved \ No newline at end of file diff --git a/dialect_postgres.go b/dialects/postgres.go similarity index 83% rename from dialect_postgres.go rename to dialects/postgres.go index b70e345..a325dfb 100644 --- a/dialect_postgres.go +++ b/dialects/postgres.go @@ -2,16 +2,18 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package dialects import ( + "context" "errors" "fmt" "net/url" "strconv" "strings" - "github.com/xormplus/core" + "github.com/xormplus/xorm/core" + "github.com/xormplus/xorm/schemas" ) // from http://www.postgresql.org/docs/current/static/sql-keywords-appendix.html @@ -765,71 +767,107 @@ var ( "ZONE": true, } + postgresQuoter = schemas.Quoter{ + Prefix: '"', + Suffix: '"', + IsReserved: schemas.AlwaysReserve, + } +) + +var ( // DefaultPostgresSchema default postgres schema DefaultPostgresSchema = "public" ) -const postgresPublicSchema = "public" - type postgres struct { - core.Base + Base } -func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { - err := db.Base.Init(d, db, uri, drivername, dataSourceName) - if err != nil { - return err +func (db *postgres) Init(uri *URI) error { + db.quoter = postgresQuoter + return db.Base.Init(db, uri) +} + +func (db *postgres) getSchema() string { + if db.uri.Schema != "" { + return db.uri.Schema } - if db.Schema == "" { - db.Schema = DefaultPostgresSchema + return DefaultPostgresSchema +} + +func (db *postgres) needQuote(name string) bool { + if db.IsReserved(name) { + return true } - return nil + for _, c := range name { + if c >= 'A' && c <= 'Z' { + return true + } + } + return false } -func (db *postgres) SqlType(c *core.Column) string { +func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = postgresQuoter + q.IsReserved = schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + var q = postgresQuoter + q.IsReserved = db.needQuote + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = postgresQuoter + } +} + +func (db *postgres) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { - case core.TinyInt: - res = core.SmallInt + case schemas.TinyInt: + res = schemas.SmallInt return res - case core.Bit: - res = core.Boolean + case schemas.Bit: + res = schemas.Boolean return res - case core.MediumInt, core.Int, core.Integer: + case schemas.MediumInt, schemas.Int, schemas.Integer: if c.IsAutoIncrement { - return core.Serial + return schemas.Serial } - return core.Integer - case core.BigInt: + return schemas.Integer + case schemas.BigInt: if c.IsAutoIncrement { - return core.BigSerial + return schemas.BigSerial } - return core.BigInt - case core.Serial, core.BigSerial: + return schemas.BigInt + case schemas.Serial, schemas.BigSerial: c.IsAutoIncrement = true c.Nullable = false res = t - case core.Binary, 
core.VarBinary: - return core.Bytea - case core.DateTime: - res = core.TimeStamp - case core.TimeStampz: + case schemas.Binary, schemas.VarBinary: + return schemas.Bytea + case schemas.DateTime: + res = schemas.TimeStamp + case schemas.TimeStampz: return "timestamp with time zone" - case core.Float: - res = core.Real - case core.TinyText, core.MediumText, core.LongText: - res = core.Text - case core.NVarchar: - res = core.Varchar - case core.Uuid: - return core.Uuid - case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob: - return core.Bytea - case core.Double: + case schemas.Float: + res = schemas.Real + case schemas.TinyText, schemas.MediumText, schemas.LongText: + res = schemas.Text + case schemas.NVarchar: + res = schemas.Varchar + case schemas.Uuid: + return schemas.Uuid + case schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob: + return schemas.Bytea + case schemas.Double: return "DOUBLE PRECISION" default: if c.IsAutoIncrement { - return core.Serial + return schemas.Serial } res = t } @@ -849,85 +887,80 @@ func (db *postgres) SqlType(c *core.Column) string { return res } -func (db *postgres) SupportInsertMany() bool { - return true -} - func (db *postgres) IsReserved(name string) bool { - _, ok := postgresReservedWords[name] + _, ok := postgresReservedWords[strings.ToUpper(name)] return ok } -func (db *postgres) Quote(name string) string { - name = strings.Replace(name, ".", `"."`, -1) - return "\"" + name + "\"" -} - func (db *postgres) AutoIncrStr() string { return "" } -func (db *postgres) SupportEngine() bool { - return false -} +func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { + var sql string + sql = "CREATE TABLE IF NOT EXISTS " + if tableName == "" { + tableName = table.Name + } -func (db *postgres) SupportCharset() bool { - return false -} + quoter := db.Quoter() + sql += quoter.Quote(tableName) + sql += " (" -func (db *postgres) IndexOnTable() bool { - return false + if len(table.ColumnsSeq()) > 0 { + pkList := table.PrimaryKeys + + for _, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) + sql += s + sql = strings.TrimSpace(sql) + sql += ", " + } + + if len(pkList) > 1 { + sql += "PRIMARY KEY ( " + sql += quoter.Join(pkList, ",") + sql += " ), " + } + + sql = sql[:len(sql)-2] + } + sql += ")" + + return []string{sql}, true } -func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - if len(db.Schema) == 0 { +func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { + if len(db.getSchema()) == 0 { args := []interface{}{tableName, idxName} return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args } - args := []interface{}{db.Schema, tableName, idxName} + args := []interface{}{db.getSchema(), tableName, idxName} return `SELECT indexname FROM pg_indexes ` + `WHERE schemaname = ? AND tablename = ? 
AND indexname = ?`, args } -func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { - if len(db.Schema) == 0 { - args := []interface{}{tableName} - return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args +func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { + if len(db.getSchema()) == 0 { + return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName) } - args := []interface{}{db.Schema, tableName} - return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args + return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`, + db.getSchema(), tableName) } -func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string { - if len(db.Schema) == 0 || strings.Contains(tableName, ".") { +func (db *postgres) ModifyColumnSQL(tableName string, col *schemas.Column) string { + if len(db.getSchema()) == 0 || strings.Contains(tableName, ".") { return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", - tableName, col.Name, db.SqlType(col)) + tableName, col.Name, db.SQLType(col)) } return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s", - db.Schema, tableName, col.Name, db.SqlType(col)) + db.getSchema(), tableName, col.Name, db.SQLType(col)) } -func (db *postgres) CreateIndexSql(tableName string, index *core.Index) string { - quote := db.Quote - var unique string - var idxName string - if index.Type == core.UniqueType { - unique = " UNIQUE" - } - idxName = index.XName(tableName) - if db.Uri.Schema != "" { - idxName = db.Uri.Schema + "." + idxName - } - - return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique, - quote(idxName), quote(tableName), - quote(strings.Join(index.Cols, quote(",")))) -} - -func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { - quote := db.Quote +func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string { idxName := index.Name tableParts := strings.Split(strings.Replace(tableName, `"`, "", -1), ".") @@ -935,30 +968,29 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { if !strings.HasPrefix(idxName, "UQE_") && !strings.HasPrefix(idxName, "IDX_") { - if index.Type == core.UniqueType { + if index.Type == schemas.UniqueType { idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name) } else { idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) } } - if db.Uri.Schema != "" { - idxName = db.Uri.Schema + "." + idxName + if db.getSchema() != "" { + idxName = db.getSchema() + "." + idxName } - return fmt.Sprintf("DROP INDEX %v", quote(idxName)) + return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } -func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { - args := []interface{}{db.Schema, tableName, colName} +func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { + args := []interface{}{db.getSchema(), tableName, colName} query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + " AND column_name = $3" - if len(db.Schema) == 0 { + if len(db.getSchema()) == 0 { args = []interface{}{tableName, colName} query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + " AND column_name = $2" } - db.LogSQL(query, args) - rows, err := db.DB().Query(query, args...) 
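For illustration of the pattern in this hunk: the old TableCheckSql returned SQL for the engine to execute, while the new IsTableExist runs the probe itself against whatever connection it is handed. A minimal caller sketch, assuming the Dialect interface exposes the same signature as the postgres implementation above (the helper and variable names here are placeholders):

    // ensureTable reports whether "user_info" exists. A *core.DB satisfies
    // core.Queryer, so the connection travels with each call instead of
    // being stored on the dialect (loadTableInfo in engine.go relies on the
    // same thing later in this patch).
    func ensureTable(dialect dialects.Dialect, db *core.DB) (bool, error) {
        return dialect.IsTableExist(db, context.Background(), "user_info")
    }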
+ rows, err := queryer.QueryContext(ctx, query, args...) if err != nil { return false, err } @@ -967,9 +999,9 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { return rows.Next(), nil } -func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { +func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} - s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, + s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey FROM pg_attribute f @@ -979,28 +1011,27 @@ FROM pg_attribute f LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name -WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` +WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` - var f string - if len(db.Schema) != 0 { - args = append(args, db.Schema) - f = " AND s.table_schema = $2" + schema := db.getSchema() + if schema != "" { + s = fmt.Sprintf(s, "AND s.table_schema = $2") + args = append(args, schema) + } else { + s = fmt.Sprintf(s, "") } - s = fmt.Sprintf(s, f) - db.LogSQL(s, args) - - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } defer rows.Close() - cols := make(map[string]*core.Column) + cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - col := new(core.Column) + col := new(schemas.Column) col.Indexes = make(map[string]int) var colName, isNullable, dataType string @@ -1011,7 +1042,6 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att return nil, nil, err } - // fmt.Println(args, colName, isNullable, dataType, maxLenStr, colDefault, isPK, isUnique) var maxLen int if maxLenStr != nil { maxLen, err = strconv.Atoi(*maxLenStr) @@ -1023,10 +1053,27 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att col.Name = strings.Trim(colName, `" `) if colDefault != nil { - col.Default = *colDefault + var theDefault = *colDefault + // cockroach has type with the default value with ::: + // and postgres with ::, we should remove them before store them + idx := strings.Index(theDefault, ":::") + if idx == -1 { + idx = strings.Index(theDefault, "::") + } + if idx > -1 { + theDefault = theDefault[:idx] + } + + if strings.HasSuffix(theDefault, "+00:00'") { + theDefault = theDefault[:len(theDefault)-7] + "'" + } + + col.Default = theDefault col.DefaultIsEmpty = false if strings.HasPrefix(col.Default, "nextval(") { col.IsAutoIncrement = true + col.Default = "" + col.DefaultIsEmpty = true } } else { col.DefaultIsEmpty = true @@ -1038,26 +1085,37 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att col.Nullable = (isNullable == "YES") - switch dataType { - case "character varying", "character": - col.SQLType = core.SQLType{Name: core.Varchar, DefaultLength: 0, DefaultLength2: 0} + switch strings.ToLower(dataType) { + case "character varying", "character", "string": + col.SQLType = schemas.SQLType{Name: 
schemas.Varchar, DefaultLength: 0, DefaultLength2: 0} case "timestamp without time zone": - col.SQLType = core.SQLType{Name: core.DateTime, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.DateTime, DefaultLength: 0, DefaultLength2: 0} case "timestamp with time zone": - col.SQLType = core.SQLType{Name: core.TimeStampz, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0} case "double precision": - col.SQLType = core.SQLType{Name: core.Double, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.Double, DefaultLength: 0, DefaultLength2: 0} case "boolean": - col.SQLType = core.SQLType{Name: core.Bool, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.Bool, DefaultLength: 0, DefaultLength2: 0} case "time without time zone": - col.SQLType = core.SQLType{Name: core.Time, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.Time, DefaultLength: 0, DefaultLength2: 0} + case "bytes": + col.SQLType = schemas.SQLType{Name: schemas.Binary, DefaultLength: 0, DefaultLength2: 0} case "oid": - col.SQLType = core.SQLType{Name: core.BigInt, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.BigInt, DefaultLength: 0, DefaultLength2: 0} + case "array": + col.SQLType = schemas.SQLType{Name: schemas.Array, DefaultLength: 0, DefaultLength2: 0} default: - col.SQLType = core.SQLType{Name: strings.ToUpper(dataType), DefaultLength: 0, DefaultLength2: 0} + startIdx := strings.Index(strings.ToLower(dataType), "string(") + if startIdx != -1 && strings.HasSuffix(dataType, ")") { + length := dataType[startIdx+8 : len(dataType)-1] + l, _ := strconv.Atoi(length) + col.SQLType = schemas.SQLType{Name: "STRING", DefaultLength: l, DefaultLength2: 0} + } else { + col.SQLType = schemas.SQLType{Name: strings.ToUpper(dataType), DefaultLength: 0, DefaultLength2: 0} + } } - if _, ok := core.SqlTypes[col.SQLType.Name]; !ok { - return nil, nil, fmt.Errorf("Unknown colType: %v", dataType) + if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok { + return nil, nil, fmt.Errorf("Unknown colType: %s - %s", dataType, col.SQLType.Name) } col.Length = maxLen @@ -1082,25 +1140,24 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att return colSeq, cols, nil } -func (db *postgres) GetTables() ([]*core.Table, error) { +func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT tablename FROM pg_tables" - if len(db.Schema) != 0 { - args = append(args, db.Schema) + schema := db.getSchema() + if schema != "" { + args = append(args, schema) s = s + " WHERE schemaname = $1" } - db.LogSQL(s, args) - - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) 
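The default-value handling above is subtle enough to restate on its own: PostgreSQL reports column defaults with a trailing ::type cast and CockroachDB with :::type, and both must be stripped before the default is stored on the column. A standalone sketch of that normalization (hypothetical helper; the patch performs it inline in GetColumns):

    // normalizeDefault strips dialect-specific cast suffixes from a reported
    // column default, e.g. `'abc'::character varying` becomes `'abc'`.
    func normalizeDefault(raw string) string {
        idx := strings.Index(raw, ":::") // CockroachDB-style cast
        if idx == -1 {
            idx = strings.Index(raw, "::") // PostgreSQL-style cast
        }
        if idx > -1 {
            raw = raw[:idx]
        }
        // timestamp literals serialized with an explicit zero offset lose it
        if strings.HasSuffix(raw, "+00:00'") {
            raw = raw[:len(raw)-7] + "'"
        }
        return raw
    }

A nextval(...) default is additionally translated into IsAutoIncrement with the default cleared, as the hunk shows.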
if err != nil { return nil, err } defer rows.Close() - tables := make([]*core.Table, 0) + tables := make([]*schemas.Table, 0) for rows.Next() { - table := core.NewEmptyTable() + table := schemas.NewEmptyTable() var name string err = rows.Scan(&name) if err != nil { @@ -1123,22 +1180,21 @@ func getIndexColName(indexdef string) []string { return colNames } -func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { +func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") - if len(db.Schema) != 0 { - args = append(args, db.Schema) + if len(db.getSchema()) != 0 { + args = append(args, db.getSchema()) s = s + " AND schemaname=$2" } - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } defer rows.Close() - indexes := make(map[string]*core.Index, 0) + indexes := make(map[string]*schemas.Index, 0) for rows.Next() { var indexType int var indexName, indexdef string @@ -1147,14 +1203,18 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) if err != nil { return nil, err } + + if indexName == "primary" { + continue + } indexName = strings.Trim(indexName, `" `) if strings.HasSuffix(indexName, "_pkey") { continue } if strings.HasPrefix(indexdef, "CREATE UNIQUE INDEX") { - indexType = core.UniqueType + indexType = schemas.UniqueType } else { - indexType = core.IndexType + indexType = schemas.IndexType } colNames = getIndexColName(indexdef) var isRegular bool @@ -1166,9 +1226,9 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) } } - index := &core.Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} + index := &schemas.Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} for _, colName := range colNames { - index.Cols = append(index.Cols, strings.Trim(colName, `" `)) + index.Cols = append(index.Cols, strings.TrimSpace(strings.Replace(colName, `"`, "", -1))) } index.IsRegular = isRegular indexes[index.Name] = index @@ -1176,8 +1236,8 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) return indexes, nil } -func (db *postgres) Filters() []core.Filter { - return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}, &core.SeqFilter{Prefix: "$", Start: 1}} +func (db *postgres) Filters() []Filter { + return []Filter{&SeqFilter{Prefix: "$", Start: 1}} } type pqDriver struct { @@ -1231,12 +1291,12 @@ func parseOpts(name string, o values) error { return nil } -func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { - db := &core.Uri{DbType: core.POSTGRES} +func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) { + db := &URI{DBType: schemas.POSTGRES} var err error if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") { - db.DbName, err = parseURL(dataSourceName) + db.DBName, err = parseURL(dataSourceName) if err != nil { return nil, err } @@ -1247,10 +1307,10 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { return nil, err } - db.DbName = o.Get("dbname") + db.DBName = o.Get("dbname") } - if db.DbName == "" { + if db.DBName == "" { return nil, errors.New("dbname is empty") } @@ -1261,10 +1321,29 @@ type pqDriverPgx struct { pqDriver } -func (pgx *pqDriverPgx) 
Parse(driverName, dataSourceName string) (*core.Uri, error) { +func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*URI, error) { // Remove the leading characters for driver to work if len(dataSourceName) >= 9 && dataSourceName[0] == 0 { dataSourceName = dataSourceName[9:] } return pgx.pqDriver.Parse(driverName, dataSourceName) } + +// QueryDefaultPostgresSchema returns the default postgres schema +func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (string, error) { + rows, err := queryer.QueryContext(ctx, "SHOW SEARCH_PATH") + if err != nil { + return "", err + } + defer rows.Close() + if rows.Next() { + var defaultSchema string + if err = rows.Scan(&defaultSchema); err != nil { + return "", err + } + parts := strings.Split(defaultSchema, ",") + return strings.TrimSpace(parts[len(parts)-1]), nil + } + + return "", errors.New("No default schema") +} diff --git a/dialect_postgres_test.go b/dialects/postgres_test.go similarity index 91% rename from dialect_postgres_test.go rename to dialects/postgres_test.go index e63021b..c0a8eb6 100644 --- a/dialect_postgres_test.go +++ b/dialects/postgres_test.go @@ -1,11 +1,10 @@ -package xorm +package dialects import ( "reflect" "testing" "github.com/stretchr/testify/assert" - "github.com/xormplus/core" ) func TestParsePostgres(t *testing.T) { @@ -27,15 +26,15 @@ func TestParsePostgres(t *testing.T) { {"dbname=db =disable", "db", false}, } - driver := core.QueryDriver("postgres") + driver := QueryDriver("postgres") for _, test := range tests { uri, err := driver.Parse("postgres", test.in) if err != nil && test.valid { t.Errorf("%q got unexpected error: %s", test.in, err) - } else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) { - t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected) + } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { + t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) } } } @@ -59,23 +58,23 @@ func TestParsePgx(t *testing.T) { {"dbname=db =disable", "db", false}, } - driver := core.QueryDriver("pgx") + driver := QueryDriver("pgx") for _, test := range tests { uri, err := driver.Parse("pgx", test.in) if err != nil && test.valid { t.Errorf("%q got unexpected error: %s", test.in, err) - } else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) { - t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected) + } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { + t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) } // Register DriverConfig uri, err = driver.Parse("pgx", test.in) if err != nil && test.valid { t.Errorf("%q got unexpected error: %s", test.in, err) - } else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) { - t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected) + } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { + t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) } } diff --git a/dialects/quote.go b/dialects/quote.go new file mode 100644 index 0000000..da4e0dd --- /dev/null +++ b/dialects/quote.go @@ -0,0 +1,15 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
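QueryDefaultPostgresSchema, added above, resolves the session's effective schema by running SHOW SEARCH_PATH and keeping the last entry, so a search_path of "$user, public" resolves to "public". A usage sketch, assuming an initialized *xorm.Engine whose DB() satisfies core.Queryer (as other call sites in this patch rely on):

    // defaultSchema keeps the last, lowest-priority search_path entry.
    func defaultSchema(engine *xorm.Engine) (string, error) {
        return dialects.QueryDefaultPostgresSchema(context.Background(), engine.DB())
    }

If no row comes back, the helper returns the "No default schema" error shown above.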
+ +package dialects + +// QuotePolicy describes quote handle policy +type QuotePolicy int + +// All QuotePolicies +const ( + QuotePolicyAlways QuotePolicy = iota + QuotePolicyNone + QuotePolicyReserved +) diff --git a/dialect_sqlite3.go b/dialects/sqlite3.go similarity index 67% rename from dialect_sqlite3.go rename to dialects/sqlite3.go index 2369daa..435901b 100644 --- a/dialect_sqlite3.go +++ b/dialects/sqlite3.go @@ -2,16 +2,18 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package dialects import ( + "context" "database/sql" "errors" "fmt" "regexp" "strings" - "github.com/xormplus/core" + "github.com/xormplus/xorm/core" + "github.com/xormplus/xorm/schemas" ) var ( @@ -141,45 +143,69 @@ var ( "WITH": true, "WITHOUT": true, } + + sqlite3Quoter = schemas.Quoter{ + Prefix: '`', + Suffix: '`', + IsReserved: schemas.AlwaysReserve, + } ) type sqlite3 struct { - core.Base + Base +} + +func (db *sqlite3) Init(uri *URI) error { + db.quoter = sqlite3Quoter + return db.Base.Init(db, uri) } -func (db *sqlite3) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = sqlite3Quoter + q.IsReserved = schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + var q = sqlite3Quoter + q.IsReserved = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = sqlite3Quoter + } } -func (db *sqlite3) SqlType(c *core.Column) string { +func (db *sqlite3) SQLType(c *schemas.Column) string { switch t := c.SQLType.Name; t { - case core.Bool: + case schemas.Bool: if c.Default == "true" { c.Default = "1" } else if c.Default == "false" { c.Default = "0" } - return core.Integer - case core.Date, core.DateTime, core.TimeStamp, core.Time: - return core.DateTime - case core.TimeStampz: - return core.Text - case core.Char, core.Varchar, core.NVarchar, core.TinyText, - core.Text, core.MediumText, core.LongText, core.Json: - return core.Text - case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt: - return core.Integer - case core.Float, core.Double, core.Real: - return core.Real - case core.Decimal, core.Numeric: - return core.Numeric - case core.TinyBlob, core.Blob, core.MediumBlob, core.LongBlob, core.Bytea, core.Binary, core.VarBinary: - return core.Blob - case core.Serial, core.BigSerial: + return schemas.Integer + case schemas.Date, schemas.DateTime, schemas.TimeStamp, schemas.Time: + return schemas.DateTime + case schemas.TimeStampz: + return schemas.Text + case schemas.Char, schemas.Varchar, schemas.NVarchar, schemas.TinyText, + schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json: + return schemas.Text + case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt: + return schemas.Integer + case schemas.Float, schemas.Double, schemas.Real: + return schemas.Real + case schemas.Decimal, schemas.Numeric: + return schemas.Numeric + case schemas.TinyBlob, schemas.Blob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea, schemas.Binary, schemas.VarBinary: + return schemas.Blob + case schemas.Serial, schemas.BigSerial: c.IsPrimaryKey = true c.IsAutoIncrement = true c.Nullable = false - return core.Integer + return schemas.Integer default: return t } @@ -189,97 +215,97 @@ func (db *sqlite3) 
FormatBytes(bs []byte) string { return fmt.Sprintf("X'%x'", bs) } -func (db *sqlite3) SupportInsertMany() bool { - return true -} - func (db *sqlite3) IsReserved(name string) bool { - _, ok := sqlite3ReservedWords[name] + _, ok := sqlite3ReservedWords[strings.ToUpper(name)] return ok } -func (db *sqlite3) Quote(name string) string { - return "`" + name + "`" -} - func (db *sqlite3) AutoIncrStr() string { return "AUTOINCREMENT" } -func (db *sqlite3) SupportEngine() bool { - return false -} - -func (db *sqlite3) SupportCharset() bool { - return false -} - -func (db *sqlite3) IndexOnTable() bool { - return false -} - -func (db *sqlite3) IndexCheckSql(tableName, idxName string) (string, []interface{}) { +func (db *sqlite3) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { args := []interface{}{idxName} return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args } -func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{tableName} - return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args +func (db *sqlite3) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { + return db.HasRecords(queryer, ctx, "SELECT name FROM sqlite_master WHERE type='table' and name = ?", tableName) } -func (db *sqlite3) CreateIndexSql(tableName string, index *core.Index) string { - quote := db.Quote - var unique string - var idxName string - if index.Type == core.UniqueType { - unique = " UNIQUE" - } - idxName = index.XName(tableName) - return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique, - quote(idxName), quote(tableName), - quote(strings.Join(index.Cols, quote(",")))) -} - -func (db *sqlite3) DropIndexSql(tableName string, index *core.Index) string { - //var unique string - quote := db.Quote +func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string { + // var unique string idxName := index.Name if !strings.HasPrefix(idxName, "UQE_") && !strings.HasPrefix(idxName, "IDX_") { - if index.Type == core.UniqueType { + if index.Type == schemas.UniqueType { idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name) } else { idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) } } - return fmt.Sprintf("DROP INDEX %v", quote(idxName)) + return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } -func (db *sqlite3) ForUpdateSql(query string) string { - return query +func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { + var sql string + sql = "CREATE TABLE IF NOT EXISTS " + if tableName == "" { + tableName = table.Name + } + + quoter := db.Quoter() + sql += quoter.Quote(tableName) + sql += " (" + + if len(table.ColumnsSeq()) > 0 { + pkList := table.PrimaryKeys + + for _, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) + sql += s + sql = strings.TrimSpace(sql) + sql += ", " + } + + if len(pkList) > 1 { + sql += "PRIMARY KEY ( " + sql += quoter.Join(pkList, ",") + sql += " ), " + } + + sql = sql[:len(sql)-2] + } + sql += ")" + + return []string{sql}, true } -/*func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{tableName} - sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? 
and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" - return sql, args -}*/ +func (db *sqlite3) ForUpdateSQL(query string) string { + return query +} -func (db *sqlite3) IsColumnExist(tableName, colName string) (bool, error) { - args := []interface{}{tableName} - query := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" - db.LogSQL(query, args) - rows, err := db.DB().Query(query, args...) +func (db *sqlite3) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { + query := "SELECT * FROM " + tableName + " LIMIT 0" + rows, err := queryer.QueryContext(ctx, query) if err != nil { return false, err } defer rows.Close() - if rows.Next() { - return true, nil + cols, err := rows.Columns() + if err != nil { + return false, err + } + + for _, col := range cols { + if strings.EqualFold(col, colName) { + return true, nil + } } + return false, nil } @@ -311,9 +337,9 @@ func splitColStr(colStr string) []string { return results } -func parseString(colStr string) (*core.Column, error) { +func parseString(colStr string) (*schemas.Column, error) { fields := splitColStr(colStr) - col := new(core.Column) + col := new(schemas.Column) col.Indexes = make(map[string]int) col.Nullable = true col.DefaultIsEmpty = true @@ -323,7 +349,7 @@ func parseString(colStr string) (*core.Column, error) { col.Name = strings.Trim(strings.Trim(field, "`[] "), `"`) continue } else if idx == 1 { - col.SQLType = core.SQLType{Name: field, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: field, DefaultLength: 0, DefaultLength2: 0} continue } switch field { @@ -345,11 +371,11 @@ func parseString(colStr string) (*core.Column, error) { return col, nil } -func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { +func (db *sqlite3) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } @@ -372,8 +398,9 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu nEnd := strings.LastIndex(name, ")") reg := regexp.MustCompile(`[^\(,\)]*(\([^\(]*\))?`) colCreates := reg.FindAllString(name[nStart+1:nEnd], -1) - cols := make(map[string]*core.Column) + cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) + for _, colStr := range colCreates { reg = regexp.MustCompile(`,\s`) colStr = reg.ReplaceAllString(colStr, ",") @@ -401,20 +428,19 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu return colSeq, cols, nil } -func (db *sqlite3) GetTables() ([]*core.Table, error) { +func (db *sqlite3) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT name FROM sqlite_master WHERE type='table'" - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) 
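The new sqlite3 IsColumnExist above replaces pattern-matching against the stored CREATE TABLE text in sqlite_master with a zero-row query whose result-set metadata is inspected. The same technique, restated as a self-contained sketch (hypothetical helper name):

    // columnExists selects no rows but still receives the column list, which
    // avoids parsing SQL text and is insensitive to `col` vs [col] quoting.
    func columnExists(ctx context.Context, queryer core.Queryer, table, col string) (bool, error) {
        rows, err := queryer.QueryContext(ctx, "SELECT * FROM "+table+" LIMIT 0")
        if err != nil {
            return false, err
        }
        defer rows.Close()
        cols, err := rows.Columns()
        if err != nil {
            return false, err
        }
        for _, c := range cols {
            if strings.EqualFold(c, col) {
                return true, nil
            }
        }
        return false, nil
    }

Note that table here comes from trusted schema metadata, as in the patch; concatenating an untrusted name this way would otherwise be an injection risk.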
if err != nil { return nil, err } defer rows.Close() - tables := make([]*core.Table, 0) + tables := make([]*schemas.Table, 0) for rows.Next() { - table := core.NewEmptyTable() + table := schemas.NewEmptyTable() err = rows.Scan(&table.Name) if err != nil { return nil, err @@ -427,18 +453,17 @@ func (db *sqlite3) GetTables() ([]*core.Table, error) { return tables, nil } -func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error) { +func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } defer rows.Close() - indexes := make(map[string]*core.Index, 0) + indexes := make(map[string]*schemas.Index, 0) for rows.Next() { var tmpSQL sql.NullString err = rows.Scan(&tmpSQL) @@ -451,7 +476,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error) } sql := tmpSQL.String - index := new(core.Index) + index := new(schemas.Index) nNStart := strings.Index(sql, "INDEX") nNEnd := strings.Index(sql, "ON") if nNStart == -1 || nNEnd == -1 { @@ -468,9 +493,9 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error) } if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") { - index.Type = core.UniqueType + index.Type = schemas.UniqueType } else { - index.Type = core.IndexType + index.Type = schemas.IndexType } nStart := strings.Index(sql, "(") @@ -488,17 +513,17 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error) return indexes, nil } -func (db *sqlite3) Filters() []core.Filter { - return []core.Filter{&core.IdFilter{}} +func (db *sqlite3) Filters() []Filter { + return []Filter{} } type sqlite3Driver struct { } -func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) { +func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) { if strings.Contains(dataSourceName, "?") { dataSourceName = dataSourceName[:strings.Index(dataSourceName, "?")] } - return &core.Uri{DbType: core.SQLITE, DbName: dataSourceName}, nil + return &URI{DBType: schemas.SQLITE, DBName: dataSourceName}, nil } diff --git a/dialect_sqlite3_test.go b/dialects/sqlite3_test.go similarity index 97% rename from dialect_sqlite3_test.go rename to dialects/sqlite3_test.go index a203615..aa6c3ce 100644 --- a/dialect_sqlite3_test.go +++ b/dialects/sqlite3_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package dialects import ( "testing" diff --git a/dialects/table_name.go b/dialects/table_name.go new file mode 100644 index 0000000..3e0c852 --- /dev/null +++ b/dialects/table_name.go @@ -0,0 +1,89 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dialects + +import ( + "fmt" + "reflect" + "strings" + + "github.com/xormplus/xorm/internal/utils" + "github.com/xormplus/xorm/names" +) + +// TableNameWithSchema will add schema prefix on table name if possible +func TableNameWithSchema(dialect Dialect, tableName string) string { + // Add schema name as prefix of table name. + // Only for postgres database. 
+ if dialect.URI().Schema != "" && + strings.Index(tableName, ".") == -1 { + return fmt.Sprintf("%s.%s", dialect.URI().Schema, tableName) + } + return tableName +} + +// TableNameNoSchema returns table name with given tableName +func TableNameNoSchema(dialect Dialect, mapper names.Mapper, tableName interface{}) string { + quote := dialect.Quoter().Quote + switch tableName.(type) { + case []string: + t := tableName.([]string) + if len(t) > 1 { + return fmt.Sprintf("%v AS %v", quote(t[0]), quote(t[1])) + } else if len(t) == 1 { + return quote(t[0]) + } + case []interface{}: + t := tableName.([]interface{}) + l := len(t) + var table string + if l > 0 { + f := t[0] + switch f.(type) { + case string: + table = f.(string) + case names.TableName: + table = f.(names.TableName).TableName() + default: + v := utils.ReflectValue(f) + t := v.Type() + if t.Kind() == reflect.Struct { + table = names.GetTableName(mapper, v) + } else { + table = quote(fmt.Sprintf("%v", f)) + } + } + } + if l > 1 { + return fmt.Sprintf("%v AS %v", quote(table), quote(fmt.Sprintf("%v", t[1]))) + } else if l == 1 { + return quote(table) + } + case names.TableName: + return tableName.(names.TableName).TableName() + case string: + return tableName.(string) + case reflect.Value: + v := tableName.(reflect.Value) + return names.GetTableName(mapper, v) + default: + v := utils.ReflectValue(tableName) + t := v.Type() + if t.Kind() == reflect.Struct { + return names.GetTableName(mapper, v) + } + return quote(fmt.Sprintf("%v", tableName)) + } + return "" +} + +// FullTableName returns table name with quote and schema according parameter +func FullTableName(dialect Dialect, mapper names.Mapper, bean interface{}, includeSchema ...bool) string { + tbName := TableNameNoSchema(dialect, mapper, bean) + if len(includeSchema) > 0 && includeSchema[0] && !utils.IsSubQuery(tbName) { + tbName = TableNameWithSchema(dialect, tbName) + } + return tbName +} diff --git a/engine_table_test.go b/dialects/table_name_test.go similarity index 59% rename from engine_table_test.go rename to dialects/table_name_test.go index 8f2300a..7d0f27f 100644 --- a/engine_table_test.go +++ b/dialects/table_name_test.go @@ -2,11 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package dialects import ( "testing" + "github.com/xormplus/xorm/names" + "github.com/stretchr/testify/assert" ) @@ -20,9 +22,9 @@ func (mcc *MCC) TableName() string { return "mcc" } -func TestTableName1(t *testing.T) { - assert.NoError(t, prepareEngine()) +func TestFullTableName(t *testing.T) { + dialect := QueryDialect("mysql") - assert.EqualValues(t, "mcc", testEngine.TableName(new(MCC))) - assert.EqualValues(t, "mcc", testEngine.TableName("mcc")) + assert.EqualValues(t, "mcc", FullTableName(dialect, names.SnakeMapper{}, &MCC{})) + assert.EqualValues(t, "mcc", FullTableName(dialect, names.SnakeMapper{}, "mcc")) } diff --git a/dialects/time.go b/dialects/time.go new file mode 100644 index 0000000..f06e749 --- /dev/null +++ b/dialects/time.go @@ -0,0 +1,49 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
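To make the dispatch in TableNameNoSchema above concrete: the helpers accept plain strings, []string pairs, names.TableName implementers, reflect values, or structs. A sketch using the MCC fixture from the test below and the registered "mysql" dialect (the backtick quoting shown is an assumption based on the MySQL quoter):

    dialect := dialects.QueryDialect("mysql")
    mapper := names.SnakeMapper{}

    dialects.FullTableName(dialect, mapper, &MCC{})      // "mcc", via the TableName() method
    dialects.FullTableName(dialect, mapper, "user_info") // plain strings pass through unquoted
    dialects.FullTableName(dialect, mapper, []string{"user_info", "u"}) // "`user_info` AS `u`"

With includeSchema set, TableNameWithSchema then prefixes dialect.URI().Schema unless the name already contains a dot or is a subquery.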
+ +package dialects + +import ( + "time" + + "github.com/xormplus/xorm/schemas" +) + +// FormatTime format time as column type +func FormatTime(dialect Dialect, sqlTypeName string, t time.Time) (v interface{}) { + switch sqlTypeName { + case schemas.Time: + s := t.Format("2006-01-02 15:04:05") // time.RFC3339 + v = s[11:19] + case schemas.Date: + v = t.Format("2006-01-02") + case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar. + v = t.Format("2006-01-02 15:04:05") + case schemas.TimeStampz: + if dialect.URI().DBType == schemas.MSSQL { + v = t.Format("2006-01-02T15:04:05.9999999Z07:00") + } else { + v = t.Format(time.RFC3339Nano) + } + case schemas.BigInt, schemas.Int: + v = t.Unix() + default: + v = t + } + return +} + +func FormatColumnTime(dialect Dialect, defaultTimeZone *time.Location, col *schemas.Column, t time.Time) (v interface{}) { + if t.IsZero() { + if col.Nullable { + return nil + } + return "" + } + + if col.TimeZone != nil { + return FormatTime(dialect, col.SQLType.Name, t.In(col.TimeZone)) + } + return FormatTime(dialect, col.SQLType.Name, t.In(defaultTimeZone)) +} diff --git a/engine.go b/engine.go index f9b48ff..45e0766 100644 --- a/engine.go +++ b/engine.go @@ -5,11 +5,8 @@ package xorm import ( - "bufio" - "bytes" "context" "database/sql" - "encoding/gob" "errors" "fmt" "io" @@ -17,73 +14,62 @@ import ( "reflect" "strconv" "strings" - "sync" "time" "github.com/fsnotify/fsnotify" - "github.com/xormplus/builder" - "github.com/xormplus/core" + + "github.com/xormplus/xorm/caches" + "github.com/xormplus/xorm/core" + "github.com/xormplus/xorm/dialects" + "github.com/xormplus/xorm/internal/utils" + "github.com/xormplus/xorm/log" + "github.com/xormplus/xorm/names" + "github.com/xormplus/xorm/schemas" + "github.com/xormplus/xorm/tags" ) // Engine is the major struct of xorm, it means a database manager. 
// Commonly, an application only need one engine type Engine struct { - db *core.DB - dialect core.Dialect - - ColumnMapper core.IMapper - TableMapper core.IMapper - TagIdentifier string - Tables map[reflect.Type]*core.Table - SqlMap SqlMap - SqlTemplate SqlTemplate - watcher *fsnotify.Watcher - mutex *sync.RWMutex - Cacher core.Cacher - - showSQL bool - showExecTime bool - - logger core.ILogger - TZLocation *time.Location // The timezone of the application - DatabaseTZ *time.Location // The timezone of the database - - disableGlobalCache bool + cacherMgr *caches.Manager + defaultContext context.Context + dialect dialects.Dialect + engineGroup *EngineGroup + logger log.ContextLogger + tagParser *tags.Parser + db *core.DB - tagHandlers map[string]tagHandler + driverName string + dataSourceName string - engineGroup *EngineGroup + SqlMap SqlMap + SqlTemplate SqlTemplate + watcher *fsnotify.Watcher - cachers map[string]core.Cacher - cacherLock sync.RWMutex + TZLocation *time.Location // The timezone of the application + DatabaseTZ *time.Location // The timezone of the database - defaultContext context.Context + logSessionID bool // create session id } -func (engine *Engine) setCacher(tableName string, cacher core.Cacher) { - engine.cacherLock.Lock() - engine.cachers[tableName] = cacher - engine.cacherLock.Unlock() +// EnableSessionID if enable session id +func (engine *Engine) EnableSessionID(enable bool) { + engine.logSessionID = enable } -func (engine *Engine) SetCacher(tableName string, cacher core.Cacher) { - engine.setCacher(tableName, cacher) +// SetCacher sets cacher for the table +func (engine *Engine) SetCacher(tableName string, cacher caches.Cacher) { + engine.cacherMgr.SetCacher(tableName, cacher) } -func (engine *Engine) getCacher(tableName string) core.Cacher { - var cacher core.Cacher - var ok bool - engine.cacherLock.RLock() - cacher, ok = engine.cachers[tableName] - engine.cacherLock.RUnlock() - if !ok && !engine.disableGlobalCache { - cacher = engine.Cacher - } - return cacher +// GetCacher returns the cachher of the special table +func (engine *Engine) GetCacher(tableName string) caches.Cacher { + return engine.cacherMgr.GetCacher(tableName) } -func (engine *Engine) GetCacher(tableName string) core.Cacher { - return engine.getCacher(tableName) +// SetQuotePolicy sets the special quote policy +func (engine *Engine) SetQuotePolicy(quotePolicy dialects.QuotePolicy) { + engine.dialect.SetQuotePolicy(quotePolicy) } // BufferSize sets buffer size for iterate @@ -93,108 +79,64 @@ func (engine *Engine) BufferSize(size int) *Session { return session.BufferSize(size) } -// CondDeleted returns the conditions whether a record is soft deleted. -func (engine *Engine) CondDeleted(col *core.Column) builder.Cond { - var cond = builder.NewCond() - if col.SQLType.IsNumeric() { - cond = builder.Eq{col.Name: 0} - } else { - // FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value. - if engine.dialect.DBType() != core.MSSQL { - cond = builder.Eq{col.Name: zeroTime1} - } - } - - if col.Nullable { - cond = cond.Or(builder.IsNull{col.Name}) - } - - return cond -} - // ShowSQL show SQL statement or not on logger if log level is great than INFO func (engine *Engine) ShowSQL(show ...bool) { engine.logger.ShowSQL(show...) 
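The struct refactor above also changes how engine-wide switches are wired: they forward to the dialect and the cacher manager instead of mutating fields on the Engine itself. A short sketch, assuming an *Engine created in the usual way:

    engine.SetQuotePolicy(dialects.QuotePolicyReserved) // quote only reserved words and mixed-case names
    engine.EnableSessionID(true)                        // tag SQL logs with a per-session id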
- if len(show) == 0 { - engine.showSQL = true - } else { - engine.showSQL = show[0] - } -} - -// ShowExecTime show SQL statement and execute time or not on logger if log level is great than INFO -func (engine *Engine) ShowExecTime(show ...bool) { - if len(show) == 0 { - engine.showExecTime = true - } else { - engine.showExecTime = show[0] - } + engine.DB().Logger = engine.logger } // Logger return the logger interface -func (engine *Engine) Logger() core.ILogger { +func (engine *Engine) Logger() log.ContextLogger { return engine.logger } // SetLogger set the new logger -func (engine *Engine) SetLogger(logger core.ILogger) { - engine.logger = logger - engine.showSQL = logger.IsShowSQL() - engine.dialect.SetLogger(logger) +func (engine *Engine) SetLogger(logger interface{}) { + var realLogger log.ContextLogger + switch t := logger.(type) { + case log.Logger: + realLogger = log.NewLoggerAdapter(t) + case log.ContextLogger: + realLogger = t + } + engine.logger = realLogger + engine.DB().Logger = realLogger } // SetLogLevel sets the logger level -func (engine *Engine) SetLogLevel(level core.LogLevel) { +func (engine *Engine) SetLogLevel(level log.LogLevel) { engine.logger.SetLevel(level) } // SetDisableGlobalCache disable global cache or not func (engine *Engine) SetDisableGlobalCache(disable bool) { - if engine.disableGlobalCache != disable { - engine.disableGlobalCache = disable - } + engine.cacherMgr.SetDisableGlobalCache(disable) } // DriverName return the current sql driver's name func (engine *Engine) DriverName() string { - return engine.dialect.DriverName() + return engine.driverName } // DataSourceName return the current connection string func (engine *Engine) DataSourceName() string { - return engine.dialect.DataSourceName() + return engine.dataSourceName } // SetMapper set the name mapping rules -func (engine *Engine) SetMapper(mapper core.IMapper) { +func (engine *Engine) SetMapper(mapper names.Mapper) { engine.SetTableMapper(mapper) engine.SetColumnMapper(mapper) } // SetTableMapper set the table name mapping rule -func (engine *Engine) SetTableMapper(mapper core.IMapper) { - engine.TableMapper = mapper +func (engine *Engine) SetTableMapper(mapper names.Mapper) { + engine.tagParser.SetTableMapper(mapper) } // SetColumnMapper set the column name mapping rule -func (engine *Engine) SetColumnMapper(mapper core.IMapper) { - engine.ColumnMapper = mapper -} - -// SupportInsertMany If engine's database support batch insert records like -// "insert into user values (name, age), (name, age)". -// When the return is ture, then engine.Insert(&users) will -// generate batch sql and exeute. 
-func (engine *Engine) SupportInsertMany() bool { - return engine.dialect.SupportInsertMany() -} - -func (engine *Engine) quoteColumns(columnStr string) string { - columns := strings.Split(columnStr, ",") - for i := 0; i < len(columns); i++ { - columns[i] = engine.Quote(strings.TrimSpace(columns[i])) - } - return strings.Join(columns, ",") +func (engine *Engine) SetColumnMapper(mapper names.Mapper) { + engine.tagParser.SetColumnMapper(mapper) } // Quote Use QuoteStr quote the string sql @@ -220,64 +162,12 @@ func (engine *Engine) QuoteTo(buf *strings.Builder, value string) { if value == "" { return } - - quoteTo(buf, engine.dialect.Quote(""), value) -} - -func quoteTo(buf *strings.Builder, quotePair string, value string) { - if len(quotePair) < 2 { // no quote - _, _ = buf.WriteString(value) - return - } - - prefix, suffix := quotePair[0], quotePair[1] - - i := 0 - for i < len(value) { - // start of a token; might be already quoted - if value[i] == '.' { - _ = buf.WriteByte('.') - i++ - } else if value[i] == prefix || value[i] == '`' { - // Has quotes; skip/normalize `name` to prefix+name+sufix - var ch byte - if value[i] == prefix { - ch = suffix - } else { - ch = '`' - } - i++ - _ = buf.WriteByte(prefix) - for ; i < len(value) && value[i] != ch; i++ { - _ = buf.WriteByte(value[i]) - } - _ = buf.WriteByte(suffix) - i++ - } else { - // Requires quotes - _ = buf.WriteByte(prefix) - for ; i < len(value) && value[i] != '.'; i++ { - _ = buf.WriteByte(value[i]) - } - _ = buf.WriteByte(suffix) - } - } -} - -func (engine *Engine) quote(sql string) string { - return engine.dialect.Quote(sql) -} - -// SqlType will be deprecated, please use SQLType instead -// -// Deprecated: use SQLType instead -func (engine *Engine) SqlType(c *core.Column) string { - return engine.SQLType(c) + engine.dialect.Quoter().QuoteTo(buf, value) } // SQLType A simple wrapper to dialect's core.SqlType method -func (engine *Engine) SQLType(c *core.Column) string { - return engine.dialect.SqlType(c) +func (engine *Engine) SQLType(c *schemas.Column) string { + return engine.dialect.SQLType(c) } // AutoIncrStr Database's autoincrement statement @@ -287,27 +177,27 @@ func (engine *Engine) AutoIncrStr() string { // SetConnMaxLifetime sets the maximum amount of time a connection may be reused. func (engine *Engine) SetConnMaxLifetime(d time.Duration) { - engine.db.SetConnMaxLifetime(d) + engine.DB().SetConnMaxLifetime(d) } // SetMaxOpenConns is only available for go 1.2+ func (engine *Engine) SetMaxOpenConns(conns int) { - engine.db.SetMaxOpenConns(conns) + engine.DB().SetMaxOpenConns(conns) } // SetMaxIdleConns set the max idle connections on pool, default is 2 func (engine *Engine) SetMaxIdleConns(conns int) { - engine.db.SetMaxIdleConns(conns) + engine.DB().SetMaxIdleConns(conns) } // SetDefaultCacher set the default cacher. Xorm's default not enable cacher. 
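SetCacher and GetCacher now delegate to the caches.Manager rather than a mutex-guarded map on the engine. A configuration sketch; NewLRUCacher and NewMemoryStore are assumed here from the files this patch moves into caches/ (lru.go, memory_store.go) and are not shown in this hunk:

    cacher := caches.NewLRUCacher(caches.NewMemoryStore(), 1000) // assumed constructors
    engine.SetDefaultCacher(cacher)       // engine-wide default
    engine.SetCacher("user_info", cacher) // or per table, by name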
-func (engine *Engine) SetDefaultCacher(cacher core.Cacher) { - engine.Cacher = cacher +func (engine *Engine) SetDefaultCacher(cacher caches.Cacher) { + engine.cacherMgr.SetDefaultCacher(cacher) } // GetDefaultCacher returns the default cacher -func (engine *Engine) GetDefaultCacher() core.Cacher { - return engine.Cacher +func (engine *Engine) GetDefaultCacher() caches.Cacher { + return engine.cacherMgr.GetDefaultCacher() } // NoCache If you has set default cacher, and you want temporilly stop use cache, @@ -326,14 +216,14 @@ func (engine *Engine) NoCascade() *Session { } // MapCacher Set a table use a special cacher -func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error { - engine.setCacher(engine.TableName(bean, true), cacher) +func (engine *Engine) MapCacher(bean interface{}, cacher caches.Cacher) error { + engine.SetCacher(dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, true), cacher) return nil } // NewDB provides an interface to operate database directly func (engine *Engine) NewDB() (*core.DB, error) { - return core.OpenDialect(engine.dialect) + return core.Open(engine.driverName, engine.dataSourceName) } // DB return the wrapper of sql.DB @@ -342,20 +232,18 @@ func (engine *Engine) DB() *core.DB { } // Dialect return database dialect -func (engine *Engine) Dialect() core.Dialect { +func (engine *Engine) Dialect() dialects.Dialect { return engine.dialect } // NewSession New a session func (engine *Engine) NewSession() *Session { - session := &Session{engine: engine} - session.Init() - return session + return newSession(engine) } // Close the engine func (engine *Engine) Close() error { - return engine.db.Close() + return engine.DB().Close() } // Ping tests if database is alive @@ -365,25 +253,6 @@ func (engine *Engine) Ping() error { return session.Ping() } -// logSQL save sql -func (engine *Engine) logSQL(session *Session, sqlStr string, sqlArgs ...interface{}) { - if engine.showSQL && !engine.showExecTime { - if len(sqlArgs) > 0 { - engine.logger.Infof("[SQL][%p] %v %#v", session, sqlStr, sqlArgs) - } else { - engine.logger.Infof("[SQL][%p] %v", session, sqlStr) - } - } -} - -// Sql provides raw sql input parameter. When you have a complex SQL statement -// and cannot use Where, Id, In and etc. Methods to describe, you can use SQL. -// -// Deprecated: use SQL instead. -func (engine *Engine) Sql(query interface{}, args ...interface{}) *Session { - return engine.SQL(query, args...) -} - // SQL method let's you manually write raw SQL and operate // For example: // @@ -418,26 +287,33 @@ func (engine *Engine) NoAutoCondition(no ...bool) *Session { return session.NoAutoCondition(no...) 
} -func (engine *Engine) loadTableInfo(table *core.Table) error { - colSeq, cols, err := engine.dialect.GetColumns(table.Name) +func (engine *Engine) loadTableInfo(table *schemas.Table) error { + colSeq, cols, err := engine.dialect.GetColumns(engine.db, engine.defaultContext, table.Name) if err != nil { return err } for _, name := range colSeq { table.AddColumn(cols[name]) } - indexes, err := engine.dialect.GetIndexes(table.Name) + indexes, err := engine.dialect.GetIndexes(engine.db, engine.defaultContext, table.Name) if err != nil { return err } table.Indexes = indexes + var seq int for _, index := range indexes { for _, name := range index.Cols { - if col := table.GetColumn(name); col != nil { + parts := strings.Split(name, " ") + if len(parts) > 1 { + if parts[1] == "DESC" { + seq = 1 + } + } + if col := table.GetColumn(parts[0]); col != nil { col.Indexes[index.Name] = index.Type } else { - return fmt.Errorf("Unknown col %s in index %v of table %v, columns %v", name, index.Name, table.Name, table.ColumnsSeq()) + return fmt.Errorf("Unknown col %s seq %d, in index %v of table %v, columns %v", name, seq, index.Name, table.Name, table.ColumnsSeq()) } } } @@ -445,8 +321,8 @@ func (engine *Engine) loadTableInfo(table *core.Table) error { } // DBMetas Retrieve all tables, columns, indexes' informations from database. -func (engine *Engine) DBMetas() ([]*core.Table, error) { - tables, err := engine.dialect.GetTables() +func (engine *Engine) DBMetas() ([]*schemas.Table, error) { + tables, err := engine.dialect.GetTables(engine.db, engine.defaultContext) if err != nil { return nil, err } @@ -460,7 +336,7 @@ func (engine *Engine) DBMetas() ([]*core.Table, error) { } // DumpAllToFile dump database all table structs and data to a file -func (engine *Engine) DumpAllToFile(fp string, tp ...core.DbType) error { +func (engine *Engine) DumpAllToFile(fp string, tp ...schemas.DBType) error { f, err := os.Create(fp) if err != nil { return err @@ -470,7 +346,7 @@ func (engine *Engine) DumpAllToFile(fp string, tp ...core.DbType) error { } // DumpAll dump database all table structs and data to w -func (engine *Engine) DumpAll(w io.Writer, tp ...core.DbType) error { +func (engine *Engine) DumpAll(w io.Writer, tp ...schemas.DBType) error { tables, err := engine.DBMetas() if err != nil { return err @@ -479,7 +355,7 @@ func (engine *Engine) DumpAll(w io.Writer, tp ...core.DbType) error { } // DumpTablesToFile dump specified tables to SQL file. -func (engine *Engine) DumpTablesToFile(tables []*core.Table, fp string, tp ...core.DbType) error { +func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp ...schemas.DBType) error { f, err := os.Create(fp) if err != nil { return err @@ -489,55 +365,70 @@ func (engine *Engine) DumpTablesToFile(tables []*core.Table, fp string, tp ...co } // DumpTables dump specify tables to io.Writer -func (engine *Engine) DumpTables(tables []*core.Table, w io.Writer, tp ...core.DbType) error { +func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { return engine.dumpTables(tables, w, tp...) 
} // dumpTables dump database all table structs and data to w with specify db type -func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.DbType) error { - var dialect core.Dialect - var distDBName string +func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { + var dstDialect dialects.Dialect if len(tp) == 0 { - dialect = engine.dialect - distDBName = string(engine.dialect.DBType()) + dstDialect = engine.dialect } else { - dialect = core.QueryDialect(tp[0]) - if dialect == nil { + dstDialect = dialects.QueryDialect(tp[0]) + if dstDialect == nil { return errors.New("Unsupported database type") } - dialect.Init(nil, engine.dialect.URI(), "", "") - distDBName = string(tp[0]) + + uri := engine.dialect.URI() + destURI := *uri + dstDialect.Init(&destURI) } - _, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm v%s %s, from %s to %s*/\n\n", - Version, time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.DBType(), strings.ToUpper(distDBName))) + _, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm %s, from %s to %s*/\n\n", + time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.URI().DBType, dstDialect.URI().DBType)) if err != nil { return err } for i, table := range tables { + tableName := table.Name + if dstDialect.URI().Schema != "" { + tableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, table.Name) + } + originalTableName := table.Name + if engine.dialect.URI().Schema != "" { + originalTableName = fmt.Sprintf("%s.%s", engine.dialect.URI().Schema, table.Name) + } if i > 0 { _, err = io.WriteString(w, "\n") if err != nil { return err } } - _, err = io.WriteString(w, dialect.CreateTableSql(table, "", table.StoreEngine, "")+";\n") - if err != nil { - return err + sqls, _ := dstDialect.CreateTableSQL(table, tableName) + for _, s := range sqls { + _, err = io.WriteString(w, s+";\n") + if err != nil { + return err + } + } + if len(table.PKColumns()) > 0 && dstDialect.URI().DBType == schemas.MSSQL { + fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", table.Name) } + for _, index := range table.Indexes { - _, err = io.WriteString(w, dialect.CreateIndexSql(table.Name, index)+";\n") + _, err = io.WriteString(w, dstDialect.CreateIndexSQL(table.Name, index)+";\n") if err != nil { return err } } cols := table.ColumnsSeq() - colNames := engine.dialect.Quote(strings.Join(cols, engine.dialect.Quote(", "))) - destColNames := dialect.Quote(strings.Join(cols, dialect.Quote(", "))) + colNames := engine.dialect.Quoter().Join(cols, ", ") + destColNames := dstDialect.Quoter().Join(cols, ", ") - rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.Quote(table.Name)) + rows, err := engine.DB().QueryContext(engine.defaultContext, "SELECT "+colNames+" FROM "+engine.Quote(originalTableName)) if err != nil { return err } @@ -550,7 +441,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D return err } - _, err = io.WriteString(w, "INSERT INTO "+dialect.Quote(table.Name)+" ("+destColNames+") VALUES (") + _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(tableName)+" ("+destColNames+") VALUES (") if err != nil { return err } @@ -573,26 +464,26 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D } } else if col.SQLType.IsBlob() { if reflect.TypeOf(d).Kind() == reflect.Slice { - temp += fmt.Sprintf(", %s", dialect.FormatBytes(d.([]byte))) + temp += fmt.Sprintf(", %s", 
dstDialect.FormatBytes(d.([]byte))) } else if reflect.TypeOf(d).Kind() == reflect.String { temp += fmt.Sprintf(", '%s'", d.(string)) } } else if col.SQLType.IsNumeric() { switch reflect.TypeOf(d).Kind() { case reflect.Slice: - if col.SQLType.Name == core.Bool { + if col.SQLType.Name == schemas.Bool { temp += fmt.Sprintf(", %v", strconv.FormatBool(d.([]byte)[0] != byte('0'))) } else { temp += fmt.Sprintf(", %s", string(d.([]byte))) } case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int: - if col.SQLType.Name == core.Bool { + if col.SQLType.Name == schemas.Bool { temp += fmt.Sprintf(", %v", strconv.FormatBool(reflect.ValueOf(d).Int() > 0)) } else { temp += fmt.Sprintf(", %v", d) } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if col.SQLType.Name == core.Bool { + if col.SQLType.Name == schemas.Bool { temp += fmt.Sprintf(", %v", strconv.FormatBool(reflect.ValueOf(d).Uint() > 0)) } else { temp += fmt.Sprintf(", %v", d) @@ -620,8 +511,8 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D } // FIXME: Hack for postgres - if string(dialect.DBType()) == core.POSTGRES && table.AutoIncrColumn() != nil { - _, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quote(table.Name)+"), 1), false);\n") + if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil { + _, err = io.WriteString(w, "SELECT setval('"+tableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(tableName)+"), 1), false);\n") if err != nil { return err } @@ -644,13 +535,6 @@ func (engine *Engine) Where(query interface{}, args ...interface{}) *Session { return session.Where(query, args...) } -// Id will be deprecated, please use ID instead -func (engine *Engine) Id(id interface{}) *Session { - session := engine.NewSession() - session.isAutoClose = true - return session.Id(id) -} - // ID method provoide a condition as (id) = ? 
func (engine *Engine) ID(id interface{}) *Session { session := engine.NewSession() @@ -858,46 +742,9 @@ func (engine *Engine) Having(conditions string) *Session { return session.Having(conditions) } -// UnMapType removes the datbase mapper of a type -func (engine *Engine) UnMapType(t reflect.Type) { - engine.mutex.Lock() - defer engine.mutex.Unlock() - delete(engine.Tables, t) -} - -func (engine *Engine) autoMapType(v reflect.Value) (*core.Table, error) { - t := v.Type() - engine.mutex.Lock() - defer engine.mutex.Unlock() - table, ok := engine.Tables[t] - if !ok { - var err error - table, err = engine.mapType(v) - if err != nil { - return nil, err - } - - engine.Tables[t] = table - if engine.Cacher != nil { - if v.CanAddr() { - engine.GobRegister(v.Addr().Interface()) - } else { - engine.GobRegister(v.Interface()) - } - } - } - return table, nil -} - -// GobRegister register one struct to gob for cache use -func (engine *Engine) GobRegister(v interface{}) *Engine { - gob.Register(v) - return engine -} - // Table table struct type Table struct { - *core.Table + *schemas.Table Name string } @@ -907,222 +754,9 @@ func (t *Table) IsValid() bool { } // TableInfo get table info according to bean's content -func (engine *Engine) TableInfo(bean interface{}) *Table { - v := rValue(bean) - tb, err := engine.autoMapType(v) - if err != nil { - engine.logger.Error(err) - } - return &Table{tb, engine.TableName(bean)} -} - -func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) { - if index, ok := table.Indexes[indexName]; ok { - index.AddColumn(col.Name) - col.Indexes[index.Name] = indexType - } else { - index := core.NewIndex(indexName, indexType) - index.AddColumn(col.Name) - table.AddIndex(index) - col.Indexes[index.Name] = indexType - } -} - -// TableName table name interface to define customerize table name -type TableName interface { - TableName() string -} - -var ( - tpTableName = reflect.TypeOf((*TableName)(nil)).Elem() -) - -func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { - t := v.Type() - table := core.NewEmptyTable() - table.Type = t - table.Name = getTableName(engine.TableMapper, v) - - var idFieldColName string - var hasCacheTag, hasNoCacheTag bool - - for i := 0; i < t.NumField(); i++ { - tag := t.Field(i).Tag - - ormTagStr := tag.Get(engine.TagIdentifier) - var col *core.Column - fieldValue := v.Field(i) - fieldType := fieldValue.Type() - - if ormTagStr != "" { - col = &core.Column{ - FieldName: t.Field(i).Name, - Nullable: true, - IsPrimaryKey: false, - IsAutoIncrement: false, - MapType: core.TWOSIDES, - Indexes: make(map[string]int), - DefaultIsEmpty: true, - } - tags := splitTag(ormTagStr) - - if len(tags) > 0 { - if tags[0] == "-" { - continue - } - - var ctx = tagContext{ - table: table, - col: col, - fieldValue: fieldValue, - indexNames: make(map[string]int), - engine: engine, - } - - if strings.HasPrefix(strings.ToUpper(tags[0]), "EXTENDS") { - pStart := strings.Index(tags[0], "(") - if pStart > -1 && strings.HasSuffix(tags[0], ")") { - var tagPrefix = strings.TrimFunc(tags[0][pStart+1:len(tags[0])-1], func(r rune) bool { - return r == '\'' || r == '"' - }) - - ctx.params = []string{tagPrefix} - } - - if err := ExtendsTagHandler(&ctx); err != nil { - return nil, err - } - continue - } - - for j, key := range tags { - if ctx.ignoreNext { - ctx.ignoreNext = false - continue - } - - k := strings.ToUpper(key) - ctx.tagName = k - ctx.params = []string{} - - pStart := strings.Index(k, "(") - if pStart == 0 { - return nil, errors.New("( 
could not be the first charactor") - } - if pStart > -1 { - if !strings.HasSuffix(k, ")") { - return nil, fmt.Errorf("field %s tag %s cannot match ) charactor", col.FieldName, key) - } - - ctx.tagName = k[:pStart] - ctx.params = strings.Split(key[pStart+1:len(k)-1], ",") - } - - if j > 0 { - ctx.preTag = strings.ToUpper(tags[j-1]) - } - if j < len(tags)-1 { - ctx.nextTag = tags[j+1] - } else { - ctx.nextTag = "" - } - - if h, ok := engine.tagHandlers[ctx.tagName]; ok { - if err := h(&ctx); err != nil { - return nil, err - } - } else { - if strings.HasPrefix(key, "'") && strings.HasSuffix(key, "'") { - col.Name = key[1 : len(key)-1] - } else { - col.Name = key - } - } - - if ctx.hasCacheTag { - hasCacheTag = true - } - if ctx.hasNoCacheTag { - hasNoCacheTag = true - } - } - - if col.SQLType.Name == "" { - col.SQLType = core.Type2SQLType(fieldType) - } - engine.dialect.SqlType(col) - if col.Length == 0 { - col.Length = col.SQLType.DefaultLength - } - if col.Length2 == 0 { - col.Length2 = col.SQLType.DefaultLength2 - } - if col.Name == "" { - col.Name = engine.ColumnMapper.Obj2Table(t.Field(i).Name) - } - - if ctx.isUnique { - ctx.indexNames[col.Name] = core.UniqueType - } else if ctx.isIndex { - ctx.indexNames[col.Name] = core.IndexType - } - - for indexName, indexType := range ctx.indexNames { - addIndex(indexName, table, col, indexType) - } - } - } else { - var sqlType core.SQLType - if fieldValue.CanAddr() { - if _, ok := fieldValue.Addr().Interface().(core.Conversion); ok { - sqlType = core.SQLType{Name: core.Text} - } - } - if _, ok := fieldValue.Interface().(core.Conversion); ok { - sqlType = core.SQLType{Name: core.Text} - } else { - sqlType = core.Type2SQLType(fieldType) - } - col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name), - t.Field(i).Name, sqlType, sqlType.DefaultLength, - sqlType.DefaultLength2, true) - - if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { - idFieldColName = col.Name - } - } - if col.IsAutoIncrement { - col.Nullable = false - } - - table.AddColumn(col) - - } // end for - - if idFieldColName != "" && len(table.PrimaryKeys) == 0 { - col := table.GetColumn(idFieldColName) - col.IsPrimaryKey = true - col.IsAutoIncrement = true - col.Nullable = false - table.PrimaryKeys = append(table.PrimaryKeys, col.Name) - table.AutoIncrement = col.Name - } - - if hasCacheTag { - if engine.Cacher != nil { // !nash! use engine's cacher if provided - engine.logger.Info("enable cache on table:", table.Name) - engine.setCacher(table.Name, engine.Cacher) - } else { - engine.logger.Info("enable LRU cache on table:", table.Name) - engine.setCacher(table.Name, NewLRUCacher2(NewMemoryStore(), time.Hour, 10000)) - } - } - if hasNoCacheTag { - engine.logger.Info("disable cache on table:", table.Name) - engine.setCacher(table.Name, nil) - } - - return table, nil +func (engine *Engine) TableInfo(bean interface{}) (*schemas.Table, error) { + v := utils.ReflectValue(bean) + return engine.tagParser.ParseWithCache(v) } // IsTableEmpty if a table has any reocrd @@ -1139,38 +773,24 @@ func (engine *Engine) IsTableExist(beanOrTableName interface{}) (bool, error) { return session.IsTableExist(beanOrTableName) } -// IdOf get id from one struct -// -// Deprecated: use IDOf instead. 
-func (engine *Engine) IdOf(bean interface{}) core.PK { - return engine.IDOf(bean) -} - // IDOf get id from one struct -func (engine *Engine) IDOf(bean interface{}) core.PK { - return engine.IdOfV(reflect.ValueOf(bean)) +func (engine *Engine) IDOf(bean interface{}) (schemas.PK, error) { + return engine.IDOfV(reflect.ValueOf(bean)) } -// IdOfV get id from one value of struct -// -// Deprecated: use IDOfV instead. -func (engine *Engine) IdOfV(rv reflect.Value) core.PK { - return engine.IDOfV(rv) +// TableName returns table name with schema prefix if has +func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string { + return dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, includeSchema...) } // IDOfV get id from one value of struct -func (engine *Engine) IDOfV(rv reflect.Value) core.PK { - pk, err := engine.idOfV(rv) - if err != nil { - engine.logger.Error(err) - return nil - } - return pk +func (engine *Engine) IDOfV(rv reflect.Value) (schemas.PK, error) { + return engine.idOfV(rv) } -func (engine *Engine) idOfV(rv reflect.Value) (core.PK, error) { +func (engine *Engine) idOfV(rv reflect.Value) (schemas.PK, error) { v := reflect.Indirect(rv) - table, err := engine.autoMapType(v) + table, err := engine.tagParser.ParseWithCache(v) if err != nil { return nil, err } @@ -1211,10 +831,10 @@ func (engine *Engine) idOfV(rv reflect.Value) (core.PK, error) { return nil, err } } - return core.PK(pk), nil + return schemas.PK(pk), nil } -func (engine *Engine) idTypeAssertion(col *core.Column, sid string) (interface{}, error) { +func (engine *Engine) idTypeAssertion(col *schemas.Column, sid string) (interface{}, error) { if col.SQLType.IsNumeric() { n, err := strconv.ParseInt(sid, 10, 64) if err != nil { @@ -1244,8 +864,8 @@ func (engine *Engine) CreateUniques(bean interface{}) error { // ClearCacheBean if enabled cache, clear the cache bean func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { - tableName := engine.TableName(bean) - cacher := engine.getCacher(tableName) + tableName := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean) + cacher := engine.GetCacher(tableName) if cacher != nil { cacher.ClearIds(tableName) cacher.DelBean(tableName, id) @@ -1256,8 +876,8 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { // ClearCache if enabled cache, clear some tables' cache func (engine *Engine) ClearCache(beans ...interface{}) error { for _, bean := range beans { - tableName := engine.TableName(bean) - cacher := engine.getCacher(tableName) + tableName := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean) + cacher := engine.GetCacher(tableName) if cacher != nil { cacher.ClearIds(tableName) cacher.ClearBeans(tableName) @@ -1266,6 +886,11 @@ func (engine *Engine) ClearCache(beans ...interface{}) error { return nil } +// UnMapType remove table from tables cache +func (engine *Engine) UnMapType(t reflect.Type) { + engine.tagParser.ClearCacheTable(t) +} + // Sync the new struct changes to database, this method will automatically add // table, column, index, unique. but will not delete or change anything. // If you change some field, you should change the database manually. 
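Note: the hunks above change several public signatures at once. IDOf/IDOfV now return (schemas.PK, error) instead of swallowing failures into the logger, and TableInfo both returns an error and yields the parsed *schemas.Table straight from the tag parser. A minimal caller-side sketch; the sqlite3 driver and the User bean are illustrative, not part of the patch:

```go
package main

import (
	"fmt"
	"log"

	_ "github.com/mattn/go-sqlite3"
	"github.com/xormplus/xorm"
)

// User is an illustrative bean; any mapped struct behaves the same.
type User struct {
	Id   int64
	Name string
}

func main() {
	engine, err := xorm.NewEngine("sqlite3", "./test.db")
	if err != nil {
		log.Fatal(err)
	}

	// IDOf now returns (schemas.PK, error), so call sites handle the
	// error explicitly instead of relying on internal logging.
	pk, err := engine.IDOf(&User{Id: 7})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(pk) // e.g. [7]

	// TableInfo likewise returns an error now and hands back the
	// parsed schema table directly.
	table, err := engine.TableInfo(new(User))
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(table.Name) // e.g. "user" with the default snake mapper
}
```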
@@ -1274,9 +899,9 @@ func (engine *Engine) Sync(beans ...interface{}) error { defer session.Close() for _, bean := range beans { - v := rValue(bean) - tableNameNoSchema := engine.TableName(bean) - table, err := engine.autoMapType(v) + v := utils.ReflectValue(bean) + tableNameNoSchema := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean) + table, err := engine.tagParser.ParseWithCache(v) if err != nil { return err } @@ -1307,12 +932,12 @@ func (engine *Engine) Sync(beans ...interface{}) error { } } else { for _, col := range table.Columns() { - isExist, err := engine.dialect.IsColumnExist(tableNameNoSchema, col.Name) + isExist, err := engine.dialect.IsColumnExist(engine.db, session.ctx, tableNameNoSchema, col.Name) if err != nil { return err } if !isExist { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } err = session.addColumn(col.Name) @@ -1323,16 +948,16 @@ func (engine *Engine) Sync(beans ...interface{}) error { } for name, index := range table.Indexes { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } - if index.Type == core.UniqueType { + if index.Type == schemas.UniqueType { isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true) if err != nil { return err } if !isExist { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } @@ -1341,13 +966,13 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } } - } else if index.Type == core.IndexType { + } else if index.Type == schemas.IndexType { isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false) if err != nil { return err } if !isExist { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } @@ -1577,112 +1202,36 @@ func (engine *Engine) SumsInt(bean interface{}, colNames ...string) ([]int64, er // ImportFile SQL DDL file func (engine *Engine) ImportFile(ddlPath string) ([]sql.Result, error) { - file, err := os.Open(ddlPath) - if err != nil { - return nil, err - } - defer file.Close() - return engine.Import(file) + session := engine.NewSession() + defer session.Close() + return session.ImportFile(ddlPath) } // Import SQL DDL from io.Reader func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) { - var results []sql.Result - var lastError error - scanner := bufio.NewScanner(r) - - semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := bytes.IndexByte(data, ';'); i >= 0 { - return i + 1, data[0:i], nil - } - // If we're at EOF, we have a final, non-terminated line. Return it. - if atEOF { - return len(data), data, nil - } - // Request more data. 
- return 0, nil, nil - } - - scanner.Split(semiColSpliter) - - for scanner.Scan() { - query := strings.Trim(scanner.Text(), " \t\n\r") - if len(query) > 0 { - // engine.logSQL(query) - engine.logger.Infof("[SQL] %v", query) - result, err := engine.DB().Exec(query) - results = append(results, result) - if err != nil { - return nil, err - } - } - } - - return results, lastError + session := engine.NewSession() + defer session.Close() + return session.Import(r) } // nowTime return current time -func (engine *Engine) nowTime(col *core.Column) (interface{}, time.Time) { +func (engine *Engine) nowTime(col *schemas.Column) (interface{}, time.Time) { t := time.Now() var tz = engine.DatabaseTZ if !col.DisableTimeZone && col.TimeZone != nil { tz = col.TimeZone } - return engine.formatTime(col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation) -} - -func (engine *Engine) formatColTime(col *core.Column, t time.Time) (v interface{}) { - if t.IsZero() { - if col.Nullable { - return nil - } - return "" - } - - if col.TimeZone != nil { - return engine.formatTime(col.SQLType.Name, t.In(col.TimeZone)) - } - return engine.formatTime(col.SQLType.Name, t.In(engine.DatabaseTZ)) -} - -// formatTime format time as column type -func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) { - switch sqlTypeName { - case core.Time: - s := t.Format("2006-01-02 15:04:05") // time.RFC3339 - v = s[11:19] - case core.Date: - v = t.Format("2006-01-02") - case core.DateTime, core.TimeStamp, core.Varchar: // !DarthPestilane! format time when sqlTypeName is core.Varchar. - v = t.Format("2006-01-02 15:04:05.999") - if engine.dialect.DBType() == "sqlite3" { - v = t.UTC().Format("2006-01-02 15:04:05.999") - } - case core.TimeStampz: - if engine.dialect.DBType() == core.MSSQL { - v = t.Format("2006-01-02T15:04:05.9999999Z07:00") - } else { - v = t.Format(time.RFC3339Nano) - } - case core.BigInt, core.Int: - v = t.Unix() - default: - v = t - } - return + return dialects.FormatTime(engine.dialect, col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation) } // GetColumnMapper returns the column name mapper -func (engine *Engine) GetColumnMapper() core.IMapper { - return engine.ColumnMapper +func (engine *Engine) GetColumnMapper() names.Mapper { + return engine.tagParser.GetColumnMapper() } // GetTableMapper returns the table name mapper -func (engine *Engine) GetTableMapper() core.IMapper { - return engine.TableMapper +func (engine *Engine) GetTableMapper() names.Mapper { + return engine.tagParser.GetTableMapper() } // GetTZLocation returns time zone of the application @@ -1707,7 +1256,7 @@ func (engine *Engine) SetTZDatabase(tz *time.Location) { // SetSchema sets the schema of database func (engine *Engine) SetSchema(schema string) { - engine.dialect.URI().Schema = schema + engine.dialect.URI().SetSchema(schema) } // Unscoped always disable struct tag "deleted" @@ -1716,3 +1265,47 @@ func (engine *Engine) Unscoped() *Session { session.isAutoClose = true return session.Unscoped() } + +func (engine *Engine) tbNameWithSchema(v string) string { + return dialects.TableNameWithSchema(engine.dialect, v) +} + +// Context creates a session with the context +func (engine *Engine) Context(ctx context.Context) *Session { + session := engine.NewSession() + session.isAutoClose = true + return session.Context(ctx) +} + +// SetDefaultContext set the default context +func (engine *Engine) SetDefaultContext(ctx context.Context) { + engine.defaultContext = ctx +} + +// PingContext tests if database is alive +func (engine *Engine) 
PingContext(ctx context.Context) error { + session := engine.NewSession() + defer session.Close() + return session.PingContext(ctx) +} + +// Transaction Execute sql wrapped in a transaction(abbr as tx), tx will automatic commit if no errors occurred +func (engine *Engine) Transaction(f func(*Session) (interface{}, error)) (interface{}, error) { + session := engine.NewSession() + defer session.Close() + + if err := session.Begin(); err != nil { + return nil, err + } + + result, err := f(session) + if err != nil { + return nil, err + } + + if err := session.Commit(); err != nil { + return nil, err + } + + return result, nil +} diff --git a/engine_cond.go b/engine_cond.go deleted file mode 100644 index d56905a..0000000 --- a/engine_cond.go +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "database/sql/driver" - "fmt" - "reflect" - "strings" - "time" - - "github.com/xormplus/builder" - "github.com/xormplus/core" -) - -func (engine *Engine) buildConds(table *core.Table, bean interface{}, - includeVersion bool, includeUpdated bool, includeNil bool, - includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, - mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) { - var conds []builder.Cond - for _, col := range table.Columns() { - if !includeVersion && col.IsVersion { - continue - } - if !includeUpdated && col.IsUpdated { - continue - } - if !includeAutoIncr && col.IsAutoIncrement { - continue - } - - if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) { - continue - } - if col.SQLType.IsJson() { - continue - } - - var colName string - if addedTableName { - var nm = tableName - if len(aliasName) > 0 { - nm = aliasName - } - colName = engine.Quote(nm) + "." 
+ engine.Quote(col.Name) - } else { - colName = engine.Quote(col.Name) - } - - fieldValuePtr, err := col.ValueOf(bean) - if err != nil { - if !strings.Contains(err.Error(), "is not valid") { - engine.logger.Warn(err) - } - continue - } - - if col.IsDeleted && !unscoped { // tag "deleted" is enabled - conds = append(conds, engine.CondDeleted(col)) - } - - fieldValue := *fieldValuePtr - if fieldValue.Interface() == nil { - continue - } - - fieldType := reflect.TypeOf(fieldValue.Interface()) - requiredField := useAllCols - - if b, ok := getFlagForColumn(mustColumnMap, col); ok { - if b { - requiredField = true - } else { - continue - } - } - - if fieldType.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - if includeNil { - conds = append(conds, builder.Eq{colName: nil}) - } - continue - } else if !fieldValue.IsValid() { - continue - } else { - // dereference ptr type to instance type - fieldValue = fieldValue.Elem() - fieldType = reflect.TypeOf(fieldValue.Interface()) - requiredField = true - } - } - - var val interface{} - switch fieldType.Kind() { - case reflect.Bool: - if allUseBool || requiredField { - val = fieldValue.Interface() - } else { - // if a bool in a struct, it will not be as a condition because it default is false, - // please use Where() instead - continue - } - case reflect.String: - if !requiredField && fieldValue.String() == "" { - continue - } - // for MyString, should convert to string or panic - if fieldType.String() != reflect.String.String() { - val = fieldValue.String() - } else { - val = fieldValue.Interface() - } - case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if !requiredField && fieldValue.Int() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Float32, reflect.Float64: - if !requiredField && fieldValue.Float() == 0.0 { - continue - } - val = fieldValue.Interface() - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if !requiredField && fieldValue.Uint() == 0 { - continue - } - t := int64(fieldValue.Uint()) - val = reflect.ValueOf(&t).Interface() - case reflect.Struct: - if fieldType.ConvertibleTo(core.TimeType) { - t := fieldValue.Convert(core.TimeType).Interface().(time.Time) - if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { - continue - } - val = engine.formatColTime(col, t) - } else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok { - continue - } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { - val, _ = valNul.Value() - if val == nil { - continue - } - } else { - if col.SQLType.IsJson() { - if col.SQLType.IsText() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = bytes - } - } else { - engine.autoMapType(fieldValue) - if table, ok := engine.Tables[fieldValue.Type()]; ok { - if len(table.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) - // fix non-int pk issues - //if pkField.Int() != 0 { - if pkField.IsValid() && !isZero(pkField.Interface()) { - val = pkField.Interface() - } else { - continue - } - } else { - //TODO: how to handler? 
- return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) - } - } else { - val = fieldValue.Interface() - } - } - } - case reflect.Array: - continue - case reflect.Slice, reflect.Map: - if fieldValue == reflect.Zero(fieldType) { - continue - } - if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { - continue - } - - if col.SQLType.IsText() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && - fieldType.Elem().Kind() == reflect.Uint8 { - if fieldValue.Len() > 0 { - val = fieldValue.Bytes() - } else { - continue - } - } else { - bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = bytes - } - } else { - continue - } - default: - val = fieldValue.Interface() - } - - conds = append(conds, builder.Eq{colName: val}) - } - - return builder.And(conds...), nil -} diff --git a/engine_context.go b/engine_context.go deleted file mode 100644 index c6cbb76..0000000 --- a/engine_context.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2019 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.8 - -package xorm - -import "context" - -// Context creates a session with the context -func (engine *Engine) Context(ctx context.Context) *Session { - session := engine.NewSession() - session.isAutoClose = true - return session.Context(ctx) -} - -// SetDefaultContext set the default context -func (engine *Engine) SetDefaultContext(ctx context.Context) { - engine.defaultContext = ctx -} - -// PingContext tests if database is alive -func (engine *Engine) PingContext(ctx context.Context) error { - session := engine.NewSession() - defer session.Close() - return session.PingContext(ctx) -} diff --git a/engine_context_test.go b/engine_context_test.go deleted file mode 100644 index 1a3276c..0000000 --- a/engine_context_test.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
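Note: engine_context.go and its test are deleted here because the go1.8 build tag is obsolete; Context, SetDefaultContext and PingContext now live unconditionally in engine.go, as the earlier hunk in this patch shows. A minimal sketch of the consolidated API, assuming a sqlite3 driver and an illustrative User bean:

```go
package main

import (
	"context"
	"log"
	"time"

	_ "github.com/mattn/go-sqlite3"
	"github.com/xormplus/xorm"
)

type User struct {
	Id int64
}

func main() {
	engine, err := xorm.NewEngine("sqlite3", "./test.db")
	if err != nil {
		log.Fatal(err)
	}
	if err := engine.Sync2(new(User)); err != nil {
		log.Fatal(err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	// PingContext honours the deadline carried by ctx.
	if err := engine.PingContext(ctx); err != nil {
		log.Fatal(err)
	}

	// Context returns an auto-closing session bound to ctx; queries on
	// it are cancelled when ctx expires.
	total, err := engine.Context(ctx).Count(new(User))
	if err != nil {
		log.Fatal(err)
	}
	log.Println("rows:", total)
}
```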
- -// +build go1.8 - -package xorm - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestPingContext(t *testing.T) { - assert.NoError(t, prepareEngine()) - - ctx, canceled := context.WithTimeout(context.Background(), time.Nanosecond) - defer canceled() - - time.Sleep(time.Nanosecond) - - err := testEngine.(*Engine).PingContext(ctx) - assert.Error(t, err) - assert.Contains(t, err.Error(), "context deadline exceeded") -} diff --git a/engine_group.go b/engine_group.go index 4b2d79b..a492529 100644 --- a/engine_group.go +++ b/engine_group.go @@ -8,7 +8,10 @@ import ( "context" "time" - "github.com/xormplus/core" + "github.com/xormplus/xorm/caches" + "github.com/xormplus/xorm/dialects" + "github.com/xormplus/xorm/log" + "github.com/xormplus/xorm/names" ) // EngineGroup defines an engine group @@ -109,10 +112,10 @@ func (eg *EngineGroup) Ping() error { } // SetColumnMapper set the column name mapping rule -func (eg *EngineGroup) SetColumnMapper(mapper core.IMapper) { - eg.Engine.ColumnMapper = mapper +func (eg *EngineGroup) SetColumnMapper(mapper names.Mapper) { + eg.Engine.SetColumnMapper(mapper) for i := 0; i < len(eg.slaves); i++ { - eg.slaves[i].ColumnMapper = mapper + eg.slaves[i].SetColumnMapper(mapper) } } @@ -125,7 +128,7 @@ func (eg *EngineGroup) SetConnMaxLifetime(d time.Duration) { } // SetDefaultCacher set the default cacher -func (eg *EngineGroup) SetDefaultCacher(cacher core.Cacher) { +func (eg *EngineGroup) SetDefaultCacher(cacher caches.Cacher) { eg.Engine.SetDefaultCacher(cacher) for i := 0; i < len(eg.slaves); i++ { eg.slaves[i].SetDefaultCacher(cacher) @@ -133,7 +136,7 @@ func (eg *EngineGroup) SetDefaultCacher(cacher core.Cacher) { } // SetLogger set the new logger -func (eg *EngineGroup) SetLogger(logger core.ILogger) { +func (eg *EngineGroup) SetLogger(logger interface{}) { eg.Engine.SetLogger(logger) for i := 0; i < len(eg.slaves); i++ { eg.slaves[i].SetLogger(logger) @@ -141,7 +144,7 @@ func (eg *EngineGroup) SetLogger(logger core.ILogger) { } // SetLogLevel sets the logger level -func (eg *EngineGroup) SetLogLevel(level core.LogLevel) { +func (eg *EngineGroup) SetLogLevel(level log.LogLevel) { eg.Engine.SetLogLevel(level) for i := 0; i < len(eg.slaves); i++ { eg.slaves[i].SetLogLevel(level) @@ -149,7 +152,7 @@ func (eg *EngineGroup) SetLogLevel(level core.LogLevel) { } // SetMapper set the name mapping rules -func (eg *EngineGroup) SetMapper(mapper core.IMapper) { +func (eg *EngineGroup) SetMapper(mapper names.Mapper) { eg.Engine.SetMapper(mapper) for i := 0; i < len(eg.slaves); i++ { eg.slaves[i].SetMapper(mapper) @@ -158,17 +161,17 @@ func (eg *EngineGroup) SetMapper(mapper core.IMapper) { // SetMaxIdleConns set the max idle connections on pool, default is 2 func (eg *EngineGroup) SetMaxIdleConns(conns int) { - eg.Engine.db.SetMaxIdleConns(conns) + eg.Engine.DB().SetMaxIdleConns(conns) for i := 0; i < len(eg.slaves); i++ { - eg.slaves[i].db.SetMaxIdleConns(conns) + eg.slaves[i].DB().SetMaxIdleConns(conns) } } // SetMaxOpenConns is only available for go 1.2+ func (eg *EngineGroup) SetMaxOpenConns(conns int) { - eg.Engine.db.SetMaxOpenConns(conns) + eg.Engine.DB().SetMaxOpenConns(conns) for i := 0; i < len(eg.slaves); i++ { - eg.slaves[i].db.SetMaxOpenConns(conns) + eg.slaves[i].DB().SetMaxOpenConns(conns) } } @@ -178,19 +181,19 @@ func (eg *EngineGroup) SetPolicy(policy GroupPolicy) *EngineGroup { return eg } -// SetTableMapper set the table name mapping rule -func (eg *EngineGroup) SetTableMapper(mapper core.IMapper) { - 
eg.Engine.TableMapper = mapper +// SetQuotePolicy sets the special quote policy +func (eg *EngineGroup) SetQuotePolicy(quotePolicy dialects.QuotePolicy) { + eg.Engine.SetQuotePolicy(quotePolicy) for i := 0; i < len(eg.slaves); i++ { - eg.slaves[i].TableMapper = mapper + eg.slaves[i].SetQuotePolicy(quotePolicy) } } -// ShowExecTime show SQL statement and execute time or not on logger if log level is great than INFO -func (eg *EngineGroup) ShowExecTime(show ...bool) { - eg.Engine.ShowExecTime(show...) +// SetTableMapper set the table name mapping rule +func (eg *EngineGroup) SetTableMapper(mapper names.Mapper) { + eg.Engine.SetTableMapper(mapper) for i := 0; i < len(eg.slaves); i++ { - eg.slaves[i].ShowExecTime(show...) + eg.slaves[i].SetTableMapper(mapper) } } diff --git a/engine_group_policy.go b/engine_group_policy.go index 5b56e89..1def8ce 100644 --- a/engine_group_policy.go +++ b/engine_group_policy.go @@ -51,6 +51,7 @@ func WeightRandomPolicy(weights []int) GroupPolicyHandler { } } +// RoundRobinPolicy returns a group policy handler func RoundRobinPolicy() GroupPolicyHandler { var pos = -1 var lock sync.Mutex @@ -68,6 +69,7 @@ func RoundRobinPolicy() GroupPolicyHandler { } } +// WeightRoundRobinPolicy returns a group policy handler func WeightRoundRobinPolicy(weights []int) GroupPolicyHandler { var rands = make([]int, 0, len(weights)) for i := 0; i < len(weights); i++ { diff --git a/engine_table.go b/engine_table.go deleted file mode 100644 index 7b99d74..0000000 --- a/engine_table.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2018 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "fmt" - "reflect" - "strings" - - "github.com/xormplus/core" -) - -// tbNameWithSchema will automatically add schema prefix on table name -func (engine *Engine) tbNameWithSchema(v string) string { - // Add schema name as prefix of table name. - // Only for postgres database. - if engine.dialect.DBType() == core.POSTGRES && - engine.dialect.URI().Schema != "" && - engine.dialect.URI().Schema != postgresPublicSchema && - strings.Index(v, ".") == -1 { - return engine.dialect.URI().Schema + "." 
+ v - } - return v -} - -func isSubQuery(tbName string) bool { - const selStr = "select" - if len(tbName) <= len(selStr)+1 { - return false - } - - return strings.EqualFold(tbName[:len(selStr)], selStr) || strings.EqualFold(tbName[:len(selStr)+1], "("+selStr) -} - -// TableName returns table name with schema prefix if has -func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string { - tbName := engine.tbNameNoSchema(bean) - if len(includeSchema) > 0 && includeSchema[0] && !isSubQuery(tbName) { - tbName = engine.tbNameWithSchema(tbName) - } - - return tbName -} - -// tbName get some table's table name -func (session *Session) tbNameNoSchema(table *core.Table) string { - if len(session.statement.AltTableName) > 0 { - return session.statement.AltTableName - } - - return table.Name -} - -func (engine *Engine) tbNameNoSchema(tablename interface{}) string { - switch tablename.(type) { - case []string: - t := tablename.([]string) - if len(t) > 1 { - return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1])) - } else if len(t) == 1 { - return engine.Quote(t[0]) - } - case []interface{}: - t := tablename.([]interface{}) - l := len(t) - var table string - if l > 0 { - f := t[0] - switch f.(type) { - case string: - table = f.(string) - case TableName: - table = f.(TableName).TableName() - default: - v := rValue(f) - t := v.Type() - if t.Kind() == reflect.Struct { - table = getTableName(engine.TableMapper, v) - } else { - table = engine.Quote(fmt.Sprintf("%v", f)) - } - } - } - if l > 1 { - return fmt.Sprintf("%v AS %v", engine.Quote(table), - engine.Quote(fmt.Sprintf("%v", t[1]))) - } else if l == 1 { - return engine.Quote(table) - } - case TableName: - return tablename.(TableName).TableName() - case string: - return tablename.(string) - case reflect.Value: - v := tablename.(reflect.Value) - return getTableName(engine.TableMapper, v) - default: - v := rValue(tablename) - t := v.Type() - if t.Kind() == reflect.Struct { - return getTableName(engine.TableMapper, v) - } - return engine.Quote(fmt.Sprintf("%v", tablename)) - } - return "" -} diff --git a/engine_test.go b/engine_test.go deleted file mode 100644 index c33f500..0000000 --- a/engine_test.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
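Note: with engine_table.go removed above, table-name resolution is delegated to dialects.FullTableName, and the schema prefix is driven by the dialect URI rather than a postgres-only special case. A rough caller-side sketch; the postgres DSN and the expected outputs in the comments are illustrative:

```go
package main

import (
	"fmt"
	"log"

	_ "github.com/lib/pq"
	"github.com/xormplus/xorm"
)

type User struct {
	Id   int64
	Name string
}

func main() {
	// DSN is illustrative.
	engine, err := xorm.NewEngine("postgres", "postgres://user:pass@localhost/db?sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	engine.SetSchema("my_schema")

	// Without the flag the bare mapped name is returned; with
	// includeSchema=true the dialect's schema is prefixed, which is
	// what dialects.FullTableName does internally.
	fmt.Println(engine.TableName(new(User)))       // e.g. "user"
	fmt.Println(engine.TableName(new(User), true)) // e.g. "my_schema.user"
}
```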
- -package xorm - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestAutoIncrTag(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type TestAutoIncr1 struct { - Id int64 - } - - tb := testEngine.TableInfo(new(TestAutoIncr1)) - cols := tb.Columns() - assert.EqualValues(t, 1, len(cols)) - assert.True(t, cols[0].IsAutoIncrement) - assert.True(t, cols[0].IsPrimaryKey) - assert.Equal(t, "id", cols[0].Name) - - type TestAutoIncr2 struct { - Id int64 `xorm:"id"` - } - - tb = testEngine.TableInfo(new(TestAutoIncr2)) - cols = tb.Columns() - assert.EqualValues(t, 1, len(cols)) - assert.False(t, cols[0].IsAutoIncrement) - assert.False(t, cols[0].IsPrimaryKey) - assert.Equal(t, "id", cols[0].Name) - - type TestAutoIncr3 struct { - Id int64 `xorm:"'ID'"` - } - - tb = testEngine.TableInfo(new(TestAutoIncr3)) - cols = tb.Columns() - assert.EqualValues(t, 1, len(cols)) - assert.False(t, cols[0].IsAutoIncrement) - assert.False(t, cols[0].IsPrimaryKey) - assert.Equal(t, "ID", cols[0].Name) - - type TestAutoIncr4 struct { - Id int64 `xorm:"pk"` - } - - tb = testEngine.TableInfo(new(TestAutoIncr4)) - cols = tb.Columns() - assert.EqualValues(t, 1, len(cols)) - assert.False(t, cols[0].IsAutoIncrement) - assert.True(t, cols[0].IsPrimaryKey) - assert.Equal(t, "id", cols[0].Name) -} - -func TestQuoteTo(t *testing.T) { - - test := func(t *testing.T, expected string, value string) { - buf := &strings.Builder{} - quoteTo(buf, "[]", value) - assert.EqualValues(t, expected, buf.String()) - } - - test(t, "[mytable]", "mytable") - test(t, "[mytable]", "`mytable`") - test(t, "[mytable]", `[mytable]`) - - test(t, `["mytable"]`, `"mytable"`) - - test(t, "[myschema].[mytable]", "myschema.mytable") - test(t, "[myschema].[mytable]", "`myschema`.mytable") - test(t, "[myschema].[mytable]", "myschema.`mytable`") - test(t, "[myschema].[mytable]", "`myschema`.`mytable`") - test(t, "[myschema].[mytable]", `[myschema].mytable`) - test(t, "[myschema].[mytable]", `myschema.[mytable]`) - test(t, "[myschema].[mytable]", `[myschema].[mytable]`) - - test(t, `["myschema].[mytable"]`, `"myschema.mytable"`) - - buf := &strings.Builder{} - quoteTo(buf, "", "noquote") - assert.EqualValues(t, "noquote", buf.String()) -} diff --git a/engineplus.go b/engineplus.go index 63f9f2f..b694872 100644 --- a/engineplus.go +++ b/engineplus.go @@ -13,7 +13,7 @@ func (engine *Engine) SqlMapClient(sqlTagName string, args ...interface{}) *Sess session := engine.NewSession() session.isAutoClose = true session.isSqlFunc = true - return session.Sql(engine.SqlMap.Sql[sqlTagName], args...) + return session.SQL(engine.SqlMap.Sql[sqlTagName], args...) 
} func (engine *Engine) SqlTemplateClient(sqlTagName string, args ...interface{}) *Session { diff --git a/error.go b/error.go index 10acc83..1eb9431 100644 --- a/error.go +++ b/error.go @@ -10,6 +10,8 @@ import ( ) var ( + // ErrPtrSliceType represents a type error + ErrPtrSliceType = errors.New("A point to a slice is needed") // ErrParamsType params error ErrParamsType = errors.New("Params type error") ErrParamsFormat = errors.New("Params format error") @@ -24,10 +26,12 @@ var ( ErrTransactionDefinition = errors.New("Transaction definition error.") // ErrCacheFailed cache failed error ErrCacheFailed = errors.New("Cache failed") + // ErrNeedDeletedCond delete needs less one condition error ErrNeedDeletedCond = errors.New("Delete action needs at least one condition") // ErrNotImplemented not implemented ErrNotImplemented = errors.New("Not implemented") + // ErrConditionType condition type unsupported ErrConditionType = errors.New("Unsupported condition type") // ErrNeedMoreArguments need more arguments diff --git a/fswatcher.go b/fswatcher.go index 4beaa18..8cfc3a4 100644 --- a/fswatcher.go +++ b/fswatcher.go @@ -24,21 +24,21 @@ func (engine *Engine) StartFSWatcher() error { if strings.HasSuffix(event.Name, engine.SqlTemplate.Extension()) { err = engine.ReloadSqlTemplate(event.Name) if err != nil { - engine.logger.Error(err) + engine.logger.Errorf("%v", err) } } if strings.HasSuffix(event.Name, engine.SqlMap.Extension["xml"]) || strings.HasSuffix(event.Name, engine.SqlMap.Extension["json"]) || strings.HasSuffix(event.Name, engine.SqlMap.Extension["xsql"]) { err = engine.reloadSqlMap(event.Name) if err != nil { - engine.logger.Error(err) + engine.logger.Errorf("%v", err) } } } case err := <-engine.watcher.Errors: if err != nil { - engine.logger.Error(err) + engine.logger.Errorf("%v", err) } } } diff --git a/helpers.go b/helpers.go deleted file mode 100644 index b4bd449..0000000 --- a/helpers.go +++ /dev/null @@ -1,343 +0,0 @@ -// Copyright 2015 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
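Note: the logger has moved off core.ILogger to the new log sub-package, whose contract is printf-style (Errorf, Infof, ...), hence the fswatcher.go changes above from logger.Error(err) to logger.Errorf("%v", err). A minimal wiring sketch; the stdlog alias and sqlite3 driver are illustrative:

```go
package main

import (
	stdlog "log"

	_ "github.com/mattn/go-sqlite3"
	"github.com/xormplus/xorm"
	"github.com/xormplus/xorm/log"
)

func main() {
	engine, err := xorm.NewEngine("sqlite3", "./test.db")
	if err != nil {
		stdlog.Fatal(err)
	}

	// Log levels now come from the xorm log sub-package rather than
	// core; LOG_INFO is the level the new engine-group test uses too.
	engine.SetLogLevel(log.LOG_INFO)
	engine.ShowSQL(true)
}
```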
- -package xorm - -import ( - "errors" - "fmt" - "reflect" - "sort" - "strconv" - "strings" - - "github.com/xormplus/core" -) - -// str2PK convert string value to primary key value according to tp -func str2PKValue(s string, tp reflect.Type) (reflect.Value, error) { - var err error - var result interface{} - var defReturn = reflect.Zero(tp) - - switch tp.Kind() { - case reflect.Int: - result, err = strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int: %s", s, err.Error()) - } - case reflect.Int8: - x, err := strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int8: %s", s, err.Error()) - } - result = int8(x) - case reflect.Int16: - x, err := strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int16: %s", s, err.Error()) - } - result = int16(x) - case reflect.Int32: - x, err := strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int32: %s", s, err.Error()) - } - result = int32(x) - case reflect.Int64: - result, err = strconv.ParseInt(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int64: %s", s, err.Error()) - } - case reflect.Uint: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint: %s", s, err.Error()) - } - result = uint(x) - case reflect.Uint8: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint8: %s", s, err.Error()) - } - result = uint8(x) - case reflect.Uint16: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint16: %s", s, err.Error()) - } - result = uint16(x) - case reflect.Uint32: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint32: %s", s, err.Error()) - } - result = uint32(x) - case reflect.Uint64: - result, err = strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint64: %s", s, err.Error()) - } - case reflect.String: - result = s - default: - return defReturn, errors.New("unsupported convert type") - } - return reflect.ValueOf(result).Convert(tp), nil -} - -func str2PK(s string, tp reflect.Type) (interface{}, error) { - v, err := str2PKValue(s, tp) - if err != nil { - return nil, err - } - return v.Interface(), nil -} - -func splitTag(tag string) (tags []string) { - tag = strings.TrimSpace(tag) - var hasQuote = false - var lastIdx = 0 - for i, t := range tag { - if t == '\'' { - hasQuote = !hasQuote - } else if t == ' ' { - if lastIdx < i && !hasQuote { - tags = append(tags, strings.TrimSpace(tag[lastIdx:i])) - lastIdx = i + 1 - } - } - } - if lastIdx < len(tag) { - tags = append(tags, strings.TrimSpace(tag[lastIdx:])) - } - return -} - -type zeroable interface { - IsZero() bool -} - -func isZero(k interface{}) bool { - switch k.(type) { - case int: - return k.(int) == 0 - case int8: - return k.(int8) == 0 - case int16: - return k.(int16) == 0 - case int32: - return k.(int32) == 0 - case int64: - return k.(int64) == 0 - case uint: - return k.(uint) == 0 - case uint8: - return k.(uint8) == 0 - case uint16: - return k.(uint16) == 0 - case uint32: - return k.(uint32) == 0 - case uint64: - return k.(uint64) == 0 - case float32: - return k.(float32) == 0 - case float64: - return k.(float64) == 0 - case bool: - return k.(bool) == false - case string: - return k.(string) == "" - case zeroable: - return k.(zeroable).IsZero() - } - return false -} - -func isZeroValue(v 
reflect.Value) bool { - if isZero(v.Interface()) { - return true - } - switch v.Kind() { - case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: - return v.IsNil() - } - return false -} - -func isStructZero(v reflect.Value) bool { - if !v.IsValid() { - return true - } - - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - switch field.Kind() { - case reflect.Ptr: - field = field.Elem() - fallthrough - case reflect.Struct: - if !isStructZero(field) { - return false - } - default: - if field.CanInterface() && !isZero(field.Interface()) { - return false - } - } - } - return true -} - -func isArrayValueZero(v reflect.Value) bool { - if !v.IsValid() || v.Len() == 0 { - return true - } - - for i := 0; i < v.Len(); i++ { - if !isZero(v.Index(i).Interface()) { - return false - } - } - - return true -} - -func int64ToIntValue(id int64, tp reflect.Type) reflect.Value { - var v interface{} - kind := tp.Kind() - - if kind == reflect.Ptr { - kind = tp.Elem().Kind() - } - - switch kind { - case reflect.Int16: - temp := int16(id) - v = &temp - case reflect.Int32: - temp := int32(id) - v = &temp - case reflect.Int: - temp := int(id) - v = &temp - case reflect.Int64: - temp := id - v = &temp - case reflect.Uint16: - temp := uint16(id) - v = &temp - case reflect.Uint32: - temp := uint32(id) - v = &temp - case reflect.Uint64: - temp := uint64(id) - v = &temp - case reflect.Uint: - temp := uint(id) - v = &temp - } - - if tp.Kind() == reflect.Ptr { - return reflect.ValueOf(v).Convert(tp) - } - return reflect.ValueOf(v).Elem().Convert(tp) -} - -func int64ToInt(id int64, tp reflect.Type) interface{} { - return int64ToIntValue(id, tp).Interface() -} - -func isPKZero(pk core.PK) bool { - for _, k := range pk { - if isZero(k) { - return true - } - } - return false -} - -func indexNoCase(s, sep string) int { - return strings.Index(strings.ToLower(s), strings.ToLower(sep)) -} - -func splitNoCase(s, sep string) []string { - idx := indexNoCase(s, sep) - if idx < 0 { - return []string{s} - } - return strings.Split(s, s[idx:idx+len(sep)]) -} - -func splitNNoCase(s, sep string, n int) []string { - idx := indexNoCase(s, sep) - if idx < 0 { - return []string{s} - } - return strings.SplitN(s, s[idx:idx+len(sep)], n) -} - -func makeArray(elem string, count int) []string { - res := make([]string, count) - for i := 0; i < count; i++ { - res[i] = elem - } - return res -} - -func rValue(bean interface{}) reflect.Value { - return reflect.Indirect(reflect.ValueOf(bean)) -} - -func rType(bean interface{}) reflect.Type { - sliceValue := reflect.Indirect(reflect.ValueOf(bean)) - // return reflect.TypeOf(sliceValue.Interface()) - return sliceValue.Type() -} - -func structName(v reflect.Type) string { - for v.Kind() == reflect.Ptr { - v = v.Elem() - } - return v.Name() -} - -func sliceEq(left, right []string) bool { - if len(left) != len(right) { - return false - } - sort.Sort(sort.StringSlice(left)) - sort.Sort(sort.StringSlice(right)) - for i := 0; i < len(left); i++ { - if left[i] != right[i] { - return false - } - } - return true -} - -func indexName(tableName, idxName string) string { - return fmt.Sprintf("IDX_%v_%v", tableName, idxName) -} - -func eraseAny(value string, strToErase ...string) string { - if len(strToErase) == 0 { - return value - } - var replaceSeq []string - for _, s := range strToErase { - replaceSeq = append(replaceSeq, s, "") - } - - replacer := strings.NewReplacer(replaceSeq...) 
- - return replacer.Replace(value) -} - -func quoteColumns(cols []string, quoteFunc func(string) string, sep string) string { - for i := range cols { - cols[i] = quoteFunc(cols[i]) - } - return strings.Join(cols, sep+" ") -} diff --git a/helpers_plus.go b/helpers_plus.go index eed8515..34f99ae 100644 --- a/helpers_plus.go +++ b/helpers_plus.go @@ -5,7 +5,8 @@ import ( "reflect" "time" - "github.com/xormplus/core" + "github.com/xormplus/xorm/core" + "github.com/xormplus/xorm/schemas" ) func reflect2objectWithDateFormat(rawValue *reflect.Value, dateFormat string) (value interface{}, err error) { @@ -30,8 +31,8 @@ func reflect2objectWithDateFormat(rawValue *reflect.Value, dateFormat string) (v } // time type case reflect.Struct: - if aa.ConvertibleTo(core.TimeType) { - value = vv.Convert(core.TimeType).Interface().(time.Time).Format(dateFormat) + if aa.ConvertibleTo(schemas.TimeType) { + value = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(dateFormat) } else { err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) } @@ -177,8 +178,8 @@ func reflect2object(rawValue *reflect.Value) (value interface{}, err error) { } // time type case reflect.Struct: - if aa.ConvertibleTo(core.TimeType) { - value = vv.Convert(core.TimeType).Interface().(time.Time) + if aa.ConvertibleTo(schemas.TimeType) { + value = vv.Convert(schemas.TimeType).Interface().(time.Time) } else { err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) } diff --git a/helpers_test.go b/helpers_test.go deleted file mode 100644 index caf7b9f..0000000 --- a/helpers_test.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestEraseAny(t *testing.T) { - raw := "SELECT * FROM `table`.[table_name]" - assert.EqualValues(t, raw, eraseAny(raw)) - assert.EqualValues(t, "SELECT * FROM table.[table_name]", eraseAny(raw, "`")) - assert.EqualValues(t, "SELECT * FROM table.table_name", eraseAny(raw, "`", "[", "]")) -} - -func TestQuoteColumns(t *testing.T) { - cols := []string{"f1", "f2", "f3"} - quoteFunc := func(value string) string { - return "[" + value + "]" - } - - assert.EqualValues(t, "[f1], [f2], [f3]", quoteColumns(cols, quoteFunc, ",")) -} diff --git a/helpler_time.go b/helpler_time.go deleted file mode 100644 index f4013e2..0000000 --- a/helpler_time.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import "time" - -const ( - zeroTime0 = "0000-00-00 00:00:00" - zeroTime1 = "0001-01-01 00:00:00" -) - -func formatTime(t time.Time) string { - return t.Format("2006-01-02 15:04:05") -} - -func isTimeZero(t time.Time) bool { - return t.IsZero() || formatTime(t) == zeroTime0 || - formatTime(t) == zeroTime1 -} diff --git a/cache_test.go b/integrations/cache_test.go similarity index 88% rename from cache_test.go rename to integrations/cache_test.go index 5f138f2..df501fa 100644 --- a/cache_test.go +++ b/integrations/cache_test.go @@ -2,17 +2,19 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
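Note: the renamed cache tests below exercise the relocated cacher API. NewLRUCacher2 and NewMemoryStore now live in the caches sub-package, and the engine's default cacher is reached through accessors instead of the old exported Cacher field. A minimal wiring sketch; the driver choice is illustrative:

```go
package main

import (
	"log"
	"time"

	_ "github.com/mattn/go-sqlite3"
	"github.com/xormplus/xorm"
	"github.com/xormplus/xorm/caches"
)

func main() {
	engine, err := xorm.NewEngine("sqlite3", "./test.db")
	if err != nil {
		log.Fatal(err)
	}

	// Previously NewLRUCacher2/NewMemoryStore were package-level xorm
	// functions; they now come from the caches sub-package.
	cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)
	engine.SetDefaultCacher(cacher)

	// Read back through the accessor rather than engine.Cacher.
	_ = engine.GetDefaultCacher()
}
```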
-package xorm +package integrations import ( "testing" "time" + "github.com/xormplus/xorm/caches" + "github.com/stretchr/testify/assert" ) func TestCacheFind(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type MailBox struct { Id int64 `xorm:"pk"` @@ -20,8 +22,8 @@ func TestCacheFind(t *testing.T) { Password string } - oldCacher := testEngine.Cacher - cacher := NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) + oldCacher := testEngine.GetDefaultCacher() + cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000) testEngine.SetDefaultCacher(cacher) assert.NoError(t, testEngine.Sync2(new(MailBox))) @@ -87,7 +89,7 @@ func TestCacheFind(t *testing.T) { } func TestCacheFind2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type MailBox2 struct { Id uint64 `xorm:"pk"` @@ -95,8 +97,8 @@ func TestCacheFind2(t *testing.T) { Password string } - oldCacher := testEngine.Cacher - cacher := NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) + oldCacher := testEngine.GetDefaultCacher() + cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000) testEngine.SetDefaultCacher(cacher) assert.NoError(t, testEngine.Sync2(new(MailBox2))) @@ -138,7 +140,7 @@ func TestCacheFind2(t *testing.T) { } func TestCacheGet(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type MailBox3 struct { Id uint64 @@ -146,8 +148,8 @@ func TestCacheGet(t *testing.T) { Password string } - oldCacher := testEngine.Cacher - cacher := NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) + oldCacher := testEngine.GetDefaultCacher() + cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000) testEngine.SetDefaultCacher(cacher) assert.NoError(t, testEngine.Sync2(new(MailBox3))) diff --git a/integrations/engine_group_test.go b/integrations/engine_group_test.go new file mode 100644 index 0000000..99610c1 --- /dev/null +++ b/integrations/engine_group_test.go @@ -0,0 +1,35 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package integrations + +import ( + "testing" + + "github.com/xormplus/xorm" + "github.com/xormplus/xorm/log" + "github.com/xormplus/xorm/schemas" + + "github.com/stretchr/testify/assert" +) + +func TestEngineGroup(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + master := testEngine.(*xorm.Engine) + if master.Dialect().URI().DBType == schemas.SQLITE { + t.Skip() + return + } + + eg, err := xorm.NewEngineGroup(master, []*xorm.Engine{master}) + assert.NoError(t, err) + + eg.SetMaxIdleConns(10) + eg.SetMaxOpenConns(100) + eg.SetTableMapper(master.GetTableMapper()) + eg.SetColumnMapper(master.GetColumnMapper()) + eg.SetLogLevel(log.LOG_INFO) + eg.ShowSQL(true) +} diff --git a/integrations/engine_test.go b/integrations/engine_test.go new file mode 100644 index 0000000..c2f988f --- /dev/null +++ b/integrations/engine_test.go @@ -0,0 +1,141 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
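Note: among the new integration tests that follow, TestAutoTransaction covers the Transaction helper added to engine.go earlier in this patch: a nil error from the closure commits, while any non-nil error leaves the transaction to be rolled back by the deferred session Close. A minimal usage sketch; the driver and the TestTx bean are illustrative:

```go
package main

import (
	"fmt"
	"log"

	_ "github.com/mattn/go-sqlite3"
	"github.com/xormplus/xorm"
)

type TestTx struct {
	Id  int64 `xorm:"autoincr pk"`
	Msg string
}

func main() {
	engine, err := xorm.NewEngine("sqlite3", "./test.db")
	if err != nil {
		log.Fatal(err)
	}
	if err := engine.Sync2(new(TestTx)); err != nil {
		log.Fatal(err)
	}

	// Returning an error from the closure aborts the transaction;
	// returning nil commits it.
	_, err = engine.Transaction(func(sess *xorm.Session) (interface{}, error) {
		if _, err := sess.Insert(&TestTx{Msg: "hello"}); err != nil {
			return nil, err
		}
		return nil, fmt.Errorf("force rollback")
	})
	if err != nil {
		log.Println("rolled back:", err)
	}
}
```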
+ +package integrations + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/xormplus/xorm" + "github.com/xormplus/xorm/schemas" + + _ "github.com/denisenkom/go-mssqldb" + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + _ "github.com/ziutek/mymysql/godrv" +) + +func TestPing(t *testing.T) { + if err := testEngine.Ping(); err != nil { + t.Fatal(err) + } +} + +func TestPingContext(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + ctx, canceled := context.WithTimeout(context.Background(), time.Nanosecond) + defer canceled() + + time.Sleep(time.Nanosecond) + + err := testEngine.(*xorm.Engine).PingContext(ctx) + assert.Error(t, err) + assert.Contains(t, err.Error(), "context deadline exceeded") +} + +func TestAutoTransaction(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestTx struct { + Id int64 `xorm:"autoincr pk"` + Msg string `xorm:"varchar(255)"` + Created time.Time `xorm:"created"` + } + + assert.NoError(t, testEngine.Sync2(new(TestTx))) + + engine := testEngine.(*xorm.Engine) + + // will success + engine.Transaction(func(session *xorm.Session) (interface{}, error) { + _, err := session.Insert(TestTx{Msg: "hi"}) + assert.NoError(t, err) + + return nil, nil + }) + + has, err := engine.Exist(&TestTx{Msg: "hi"}) + assert.NoError(t, err) + assert.EqualValues(t, true, has) + + // will rollback + _, err = engine.Transaction(func(session *xorm.Session) (interface{}, error) { + _, err := session.Insert(TestTx{Msg: "hello"}) + assert.NoError(t, err) + + return nil, fmt.Errorf("rollback") + }) + assert.Error(t, err) + + has, err = engine.Exist(&TestTx{Msg: "hello"}) + assert.NoError(t, err) + assert.EqualValues(t, false, has) +} + +func assertSync(t *testing.T, beans ...interface{}) { + for _, bean := range beans { + t.Run(testEngine.TableName(bean, true), func(t *testing.T) { + assert.NoError(t, testEngine.DropTables(bean)) + assert.NoError(t, testEngine.Sync2(bean)) + }) + } +} + +func TestDump(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestDumpStruct struct { + Id int64 + Name string + } + + assertSync(t, new(TestDumpStruct)) + + testEngine.Insert([]TestDumpStruct{ + {Name: "1"}, + {Name: "2\n"}, + {Name: "3;"}, + {Name: "4\n;\n''"}, + {Name: "5'\n"}, + }) + + fp := fmt.Sprintf("%v.sql", testEngine.Dialect().URI().DBType) + os.Remove(fp) + assert.NoError(t, testEngine.DumpAllToFile(fp)) + + assert.NoError(t, PrepareEngine()) + + sess := testEngine.NewSession() + defer sess.Close() + assert.NoError(t, sess.Begin()) + _, err := sess.ImportFile(fp) + assert.NoError(t, err) + assert.NoError(t, sess.Commit()) + + for _, tp := range []schemas.DBType{schemas.SQLITE, schemas.MYSQL, schemas.POSTGRES, schemas.MSSQL} { + name := fmt.Sprintf("dump_%v.sql", tp) + t.Run(name, func(t *testing.T) { + assert.NoError(t, testEngine.DumpAllToFile(name, tp)) + }) + } +} + +func TestSetSchema(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + if testEngine.Dialect().URI().DBType == schemas.POSTGRES { + oldSchema := testEngine.Dialect().URI().Schema + testEngine.SetSchema("my_schema") + assert.EqualValues(t, "my_schema", testEngine.Dialect().URI().Schema) + testEngine.SetSchema(oldSchema) + assert.EqualValues(t, oldSchema, testEngine.Dialect().URI().Schema) + } +} diff --git a/types.go b/integrations/main_test.go similarity index 55% rename from types.go rename to integrations/main_test.go index feaf235..225ae45 100644 --- a/types.go +++ 
b/integrations/main_test.go @@ -2,15 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( - "reflect" - - "github.com/xormplus/core" + "testing" ) -var ( - ptrPkType = reflect.TypeOf(&core.PK{}) - pkType = reflect.TypeOf(core.PK{}) -) +func TestMain(m *testing.M) { + MainTest(m) +} diff --git a/processors_test.go b/integrations/processors_test.go similarity index 90% rename from processors_test.go rename to integrations/processors_test.go index d1efc04..c022d88 100644 --- a/processors_test.go +++ b/integrations/processors_test.go @@ -2,18 +2,20 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "errors" "fmt" "testing" + "github.com/xormplus/xorm" + "github.com/stretchr/testify/assert" ) func TestBefore_Get(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type BeforeTable struct { Id int64 @@ -40,7 +42,7 @@ func TestBefore_Get(t *testing.T) { } func TestBefore_Find(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type BeforeTable2 struct { Id int64 @@ -101,7 +103,7 @@ func (p *ProcessorsStruct) BeforeDelete() { p.B4DeleteFlag = 1 } -func (p *ProcessorsStruct) BeforeSet(col string, cell Cell) { +func (p *ProcessorsStruct) BeforeSet(col string, cell xorm.Cell) { p.BeforeSetFlag = p.BeforeSetFlag + 1 } @@ -117,25 +119,19 @@ func (p *ProcessorsStruct) AfterDelete() { p.AfterDeletedFlag = 1 } -func (p *ProcessorsStruct) AfterSet(col string, cell Cell) { +func (p *ProcessorsStruct) AfterSet(col string, cell xorm.Cell) { p.AfterSetFlag = p.AfterSetFlag + 1 } func TestProcessors(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) err := testEngine.DropTables(&ProcessorsStruct{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) p := &ProcessorsStruct{} err = testEngine.CreateTables(&ProcessorsStruct{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) b4InsertFunc := func(bean interface{}) { if v, ok := (bean).(*ProcessorsStruct); ok { @@ -259,42 +255,22 @@ func TestProcessors(t *testing.T) { _, err = testEngine.Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) assert.NoError(t, err) - if p.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p.AfterUpdatedFlag == 0 { - t.Error(errors.New("AfterUpdatedFlag not set")) - } - if p.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p.AfterUpdatedViaExt == 0 { - t.Error(errors.New("AfterUpdatedViaExt not set")) - } + assert.False(t, p.B4UpdateFlag == 0, "B4UpdateFlag not set") + assert.False(t, p.AfterUpdatedFlag == 0, "AfterUpdatedFlag not set") + assert.False(t, p.B4UpdateViaExt == 0, "B4UpdateViaExt not set") + assert.False(t, p.AfterUpdatedViaExt == 0, "AfterUpdatedViaExt not set") p2 = &ProcessorsStruct{} has, err = testEngine.ID(p.Id).Get(p2) assert.NoError(t, err) assert.True(t, has) - if p2.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p2.AfterUpdatedFlag != 0 { - t.Error(errors.New("AfterUpdatedFlag is set: " + string(p.AfterUpdatedFlag))) - } - if p2.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p2.AfterUpdatedViaExt != 0 { - t.Error(errors.New("AfterUpdatedViaExt is set: " + string(p.AfterUpdatedViaExt))) - } - if p2.BeforeSetFlag != 9 { - 
t.Error(fmt.Errorf("BeforeSetFlag is %d not 9", p2.BeforeSetFlag)) - } - if p2.AfterSetFlag != 9 { - t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p2.BeforeSetFlag)) - } + assert.False(t, p2.B4UpdateFlag == 0, "B4UpdateFlag not set") + assert.False(t, p2.AfterUpdatedFlag != 0, fmt.Sprintf("AfterUpdatedFlag is set: %d", p.AfterUpdatedFlag)) + assert.False(t, p2.B4UpdateViaExt == 0, "B4UpdateViaExt not set") + assert.False(t, p2.AfterUpdatedViaExt != 0, fmt.Sprintf("AfterUpdatedViaExt is set: %d", p.AfterUpdatedViaExt)) + assert.False(t, p2.BeforeSetFlag != 9, fmt.Sprintf("BeforeSetFlag is %d not 9", p2.BeforeSetFlag)) + assert.False(t, p2.AfterSetFlag != 9, fmt.Sprintf("AfterSetFlag is %d not 9", p2.BeforeSetFlag)) // -- // test delete processors @@ -382,7 +358,7 @@ func TestProcessors(t *testing.T) { } func TestProcessorsTx(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) err := testEngine.DropTables(&ProcessorsStruct{}) assert.NoError(t, err) @@ -450,12 +426,7 @@ func TestProcessorsTx(t *testing.T) { p2 := &ProcessorsStruct{} _, err = testEngine.ID(p.Id).Get(p2) assert.NoError(t, err) - - if p2.Id > 0 { - err = errors.New("tx got committed upon insert!?") - t.Error(err) - panic(err) - } + assert.False(t, p2.Id > 0, "tx got committed upon insert!?") // -- // test insert processors with tx commit @@ -516,7 +487,7 @@ func TestProcessorsTx(t *testing.T) { t.Error(errors.New("AfterInsertedViaExt is set")) } - insertedId := p2.Id + insertedID := p2.Id // -- // test update processors with tx rollback @@ -544,7 +515,7 @@ func TestProcessorsTx(t *testing.T) { p = p2 // reset - _, err = session.ID(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) + _, err = session.ID(insertedID).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) assert.NoError(t, err) if p.B4UpdateFlag == 0 { @@ -579,7 +550,7 @@ func TestProcessorsTx(t *testing.T) { session.Close() p2 = &ProcessorsStruct{} - _, err = testEngine.ID(insertedId).Get(p2) + _, err = testEngine.ID(insertedID).Get(p2) assert.NoError(t, err) if p2.B4UpdateFlag != 0 { @@ -603,7 +574,7 @@ func TestProcessorsTx(t *testing.T) { err = session.Begin() assert.NoError(t, err) - p = &ProcessorsStruct{Id: insertedId} + p = &ProcessorsStruct{Id: insertedID} _, err = session.Update(p) assert.NoError(t, err) @@ -642,7 +613,7 @@ func TestProcessorsTx(t *testing.T) { p = &ProcessorsStruct{} - _, err = session.ID(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) + _, err = session.ID(insertedID).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) assert.NoError(t, err) if p.B4UpdateFlag == 0 { @@ -676,7 +647,7 @@ func TestProcessorsTx(t *testing.T) { session.Close() p2 = &ProcessorsStruct{} - _, err = testEngine.ID(insertedId).Get(p2) + _, err = testEngine.ID(insertedID).Get(p2) assert.NoError(t, err) if p.B4UpdateFlag == 0 { @@ -718,7 +689,7 @@ func TestProcessorsTx(t *testing.T) { p = &ProcessorsStruct{} // reset - _, err = session.ID(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) + _, err = session.ID(insertedID).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) assert.NoError(t, err) if p.B4DeleteFlag == 0 { @@ -752,7 +723,7 @@ func TestProcessorsTx(t *testing.T) { session.Close() p2 = &ProcessorsStruct{} - _, err = testEngine.ID(insertedId).Get(p2) + _, err = testEngine.ID(insertedID).Get(p2) assert.NoError(t, err) if p2.B4DeleteFlag != 0 { @@ -778,7 +749,7 @@ func TestProcessorsTx(t *testing.T) { p = &ProcessorsStruct{} - _, err = 
session.ID(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) + _, err = session.ID(insertedID).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) assert.NoError(t, err) if p.B4DeleteFlag == 0 { @@ -819,7 +790,7 @@ func TestProcessorsTx(t *testing.T) { err = session.Begin() assert.NoError(t, err) - p = &ProcessorsStruct{Id: insertedId} + p = &ProcessorsStruct{Id: insertedID} _, err = session.Delete(p) assert.NoError(t, err) @@ -846,7 +817,6 @@ func TestProcessorsTx(t *testing.T) { t.Error(errors.New("AfterUpdatedFlag set")) } session.Close() - // -- } type AfterLoadStructA struct { @@ -862,19 +832,19 @@ type AfterLoadStructB struct { Err error `xorm:"-"` } -func (s *AfterLoadStructB) AfterLoad(session *Session) { +func (s *AfterLoadStructB) AfterLoad(session *xorm.Session) { has, err := session.ID(s.AId).NoAutoCondition().Get(&s.A) if err != nil { s.Err = err return } if !has { - s.Err = ErrNotExist + s.Err = xorm.ErrNotExist } } func TestAfterLoadProcessor(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(AfterLoadStructA), new(AfterLoadStructB)) @@ -925,7 +895,7 @@ func (a *AfterInsertStruct) AfterInsert() { } func TestAfterInsert(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(AfterInsertStruct)) diff --git a/rows_test.go b/integrations/rows_test.go similarity index 87% rename from rows_test.go rename to integrations/rows_test.go index af33386..f68030a 100644 --- a/rows_test.go +++ b/integrations/rows_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "testing" @@ -11,7 +11,7 @@ import ( ) func TestRows(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserRows struct { Id int64 @@ -85,7 +85,7 @@ func TestRows(t *testing.T) { } func TestRowsMyTableName(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserRowsMyTable struct { Id int64 @@ -104,7 +104,6 @@ func TestRowsMyTableName(t *testing.T) { rows, err := testEngine.Table(tableName).Rows(new(UserRowsMyTable)) assert.NoError(t, err) - defer rows.Close() cnt = 0 user := new(UserRowsMyTable) @@ -114,6 +113,21 @@ func TestRowsMyTableName(t *testing.T) { cnt++ } assert.EqualValues(t, 1, cnt) + + rows.Close() + + rows, err = testEngine.Table(tableName).Rows(&UserRowsMyTable{ + Id: 2, + }) + assert.NoError(t, err) + cnt = 0 + user = new(UserRowsMyTable) + for rows.Next() { + err = rows.Scan(user) + assert.NoError(t, err) + cnt++ + } + assert.EqualValues(t, 0, cnt) } type UserRowsSpecTable struct { @@ -126,7 +140,7 @@ func (UserRowsSpecTable) TableName() string { } func TestRowsSpecTableName(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync2(new(UserRowsSpecTable))) cnt, err := testEngine.Insert(&UserRowsSpecTable{ diff --git a/session_cols_test.go b/integrations/session_cols_test.go similarity index 93% rename from session_cols_test.go rename to integrations/session_cols_test.go index f501272..9d19991 100644 --- a/session_cols_test.go +++ b/integrations/session_cols_test.go @@ -2,18 +2,18 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
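For reviewers unfamiliar with the Rows API exercised by the rows_test.go hunks earlier in this patch: Rows returns a cursor, and passing a non-zero bean adds it as an implicit query condition, which is what the new assertions check (a bean with Id: 2 matches nothing). The following is a minimal, self-contained sketch of that pattern; the package main wrapper, the sqlite3 driver, the in-memory DSN, and the UserRows struct are illustrative assumptions, not part of this patch.

package main

import (
	"fmt"

	_ "github.com/mattn/go-sqlite3"
	"github.com/xormplus/xorm"
)

// UserRows mirrors the shape of the bean used in TestRowsMyTableName.
type UserRows struct {
	Id   int64
	Name string
}

func main() {
	// Driver and DSN are illustrative assumptions.
	engine, err := xorm.NewEngine("sqlite3", ":memory:")
	if err != nil {
		panic(err)
	}
	defer engine.Close()

	if err := engine.Sync2(new(UserRows)); err != nil {
		panic(err)
	}
	if _, err := engine.Insert(&UserRows{Name: "a"}); err != nil {
		panic(err)
	}

	// Rows iterates the result set one record at a time instead of
	// loading everything, which is what the tests above exercise.
	rows, err := engine.Rows(new(UserRows))
	if err != nil {
		panic(err)
	}
	defer rows.Close()

	user := new(UserRows)
	for rows.Next() {
		if err := rows.Scan(user); err != nil {
			panic(err)
		}
		fmt.Println(user.Id, user.Name)
	}
}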
-package xorm +package integrations import ( "testing" "github.com/stretchr/testify/assert" "github.com/xormplus/builder" - "github.com/xormplus/core" + "github.com/xormplus/xorm/schemas" ) func TestSetExpr(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserExprIssue struct { Id int64 @@ -45,7 +45,7 @@ func TestSetExpr(t *testing.T) { assert.EqualValues(t, 1, cnt) var not = "NOT" - if testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { not = "~" } cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(UserExpr)) @@ -64,7 +64,7 @@ func TestSetExpr(t *testing.T) { } func TestCols(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type ColsTable struct { Id int64 @@ -96,7 +96,7 @@ func TestCols(t *testing.T) { } func TestMustCol(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type CustomerUpdate struct { Id int64 `form:"id" json:"id"` diff --git a/session_cond_test.go b/integrations/session_cond_test.go similarity index 88% rename from session_cond_test.go rename to integrations/session_cond_test.go index c45b682..3eca46e 100644 --- a/session_cond_test.go +++ b/integrations/session_cond_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "errors" @@ -14,7 +14,7 @@ import ( ) func TestBuilder(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) const ( OpEqual int = iota @@ -102,7 +102,7 @@ func TestBuilder(t *testing.T) { } func TestIn(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync2(new(Userinfo))) cnt, err := testEngine.Insert([]Userinfo{ @@ -137,15 +137,13 @@ func TestIn(t *testing.T) { idsStr = idsStr[:len(idsStr)-1] users := make([]Userinfo, 0) - err = testEngine.In("(id)", ids[0], ids[1], ids[2]).Find(&users) + err = testEngine.In("id", ids[0], ids[1], ids[2]).Find(&users) assert.NoError(t, err) - fmt.Println(users) assert.EqualValues(t, 3, len(users)) users = make([]Userinfo, 0) - err = testEngine.In("(id)", ids).Find(&users) + err = testEngine.In("id", ids).Find(&users) assert.NoError(t, err) - fmt.Println(users) assert.EqualValues(t, 3, len(users)) for _, user := range users { @@ -161,9 +159,8 @@ func TestIn(t *testing.T) { idsInterface = append(idsInterface, id) } - err = testEngine.Where(department+" = ?", "dev").In("(id)", idsInterface...).Find(&users) + err = testEngine.Where(department+" = ?", "dev").In("id", idsInterface...).Find(&users) assert.NoError(t, err) - fmt.Println(users) assert.EqualValues(t, 3, len(users)) for _, user := range users { @@ -175,11 +172,10 @@ func TestIn(t *testing.T) { dev := testEngine.GetColumnMapper().Obj2Table("Dev") - err = testEngine.In("(id)", 1).In("(id)", 2).In(department, dev).Find(&users) + err = testEngine.In("id", 1).In("id", 2).In(department, dev).Find(&users) assert.NoError(t, err) - fmt.Println(users) - cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev-"}) + cnt, err = testEngine.In("id", ids[0]).Update(&Userinfo{Departname: "dev-"}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) @@ -189,17 +185,17 @@ func TestIn(t *testing.T) { assert.True(t, has) assert.EqualValues(t, "dev-", user.Departname) - cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev"}) + cnt, err 
= testEngine.In("id", ids[0]).Update(&Userinfo{Departname: "dev"}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - cnt, err = testEngine.In("(id)", ids[1]).Delete(&Userinfo{}) + cnt, err = testEngine.In("id", ids[1]).Delete(&Userinfo{}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) } func TestFindAndCount(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type FindAndCount struct { Id int64 diff --git a/session_delete_test.go b/integrations/session_delete_test.go similarity index 93% rename from session_delete_test.go rename to integrations/session_delete_test.go index c771a70..d8da7d1 100644 --- a/session_delete_test.go +++ b/integrations/session_delete_test.go @@ -2,18 +2,20 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "testing" "time" + "github.com/xormplus/xorm/caches" + "github.com/xormplus/xorm/schemas" + "github.com/stretchr/testify/assert" - "github.com/xormplus/core" ) func TestDelete(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserinfoDelete struct { Uid int64 `xorm:"id pk not null autoincr"` @@ -26,7 +28,7 @@ func TestDelete(t *testing.T) { defer session.Close() var err error - if testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Begin() assert.NoError(t, err) _, err = session.Exec("SET IDENTITY_INSERT userinfo_delete ON") @@ -38,7 +40,7 @@ func TestDelete(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - if testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Commit() assert.NoError(t, err) } @@ -69,7 +71,7 @@ func TestDelete(t *testing.T) { } func TestDeleted(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type Deleted struct { Id int64 `xorm:"pk"` @@ -156,10 +158,10 @@ func TestDeleted(t *testing.T) { } func TestCacheDelete(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) oldCacher := testEngine.GetDefaultCacher() - cacher := NewLRUCacher(NewMemoryStore(), 1000) + cacher := caches.NewLRUCacher(caches.NewMemoryStore(), 1000) testEngine.SetDefaultCacher(cacher) type CacheDeleteStruct struct { @@ -188,7 +190,7 @@ func TestCacheDelete(t *testing.T) { } func TestUnscopeDelete(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UnscopeDeleteStruct struct { Id int64 diff --git a/session_exist_test.go b/integrations/session_exist_test.go similarity index 86% rename from session_exist_test.go rename to integrations/session_exist_test.go index 2792654..6247c91 100644 --- a/session_exist_test.go +++ b/integrations/session_exist_test.go @@ -2,16 +2,18 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-package xorm +package integrations import ( + "context" "testing" + "time" "github.com/stretchr/testify/assert" ) func TestExistStruct(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type RecordExist struct { Id int64 @@ -76,7 +78,7 @@ func TestExistStruct(t *testing.T) { } func TestExistStructForJoin(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type Number struct { Id int64 @@ -181,3 +183,26 @@ func TestExistStructForJoin(t *testing.T) { assert.NoError(t, err) assert.True(t, has) } + +func TestExistContext(t *testing.T) { + type ContextQueryStruct struct { + Id int64 + Name string + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(ContextQueryStruct)) + + _, err := testEngine.Insert(&ContextQueryStruct{Name: "1"}) + assert.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) + defer cancel() + + time.Sleep(time.Nanosecond) + + has, err := testEngine.Context(ctx).Exist(&ContextQueryStruct{Name: "1"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "context deadline exceeded") + assert.False(t, has) +} diff --git a/session_find_test.go b/integrations/session_find_test.go similarity index 66% rename from session_find_test.go rename to integrations/session_find_test.go index c5311b1..cc6633c 100644 --- a/session_find_test.go +++ b/integrations/session_find_test.go @@ -2,20 +2,20 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( - "errors" - "fmt" "testing" "time" + "github.com/xormplus/xorm/internal/utils" + "github.com/xormplus/xorm/names" + "github.com/stretchr/testify/assert" - "github.com/xormplus/core" ) func TestJoinLimit(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type Salary struct { Id int64 @@ -62,47 +62,27 @@ func TestJoinLimit(t *testing.T) { assert.NoError(t, err) } -func assertSync(t *testing.T, beans ...interface{}) { - for _, bean := range beans { - t.Run(testEngine.TableName(bean, true), func(t *testing.T) { - assert.NoError(t, testEngine.DropTables(bean)) - assert.NoError(t, testEngine.Sync2(bean)) - }) - } -} - func TestWhere(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) - err := testEngine.Where("(id) > ?", 2).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users) + err := testEngine.Where("id > ?", 2).Find(&users) + assert.NoError(t, err) - err = testEngine.Where("(id) > ?", 2).And("(id) < ?", 10).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users) + err = testEngine.Where("id > ?", 2).And("id < ?", 10).Find(&users) + assert.NoError(t, err) } func TestFind(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) err := testEngine.Find(&users) assert.NoError(t, err) - for _, user := range users { - fmt.Println(user) - } users2 := make([]Userinfo, 0) var tbName = testEngine.Quote(testEngine.TableName(new(Userinfo), true)) @@ -111,17 +91,13 @@ func TestFind(t *testing.T) { } func TestFind2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) users := make([]*Userinfo, 0) assertSync(t, new(Userinfo)) err := testEngine.Find(&users) assert.NoError(t, err) - - for _, user := range users { - 
fmt.Println(user) - } } type Team struct { @@ -140,7 +116,7 @@ func (TeamUser) TableName() string { func TestFind3(t *testing.T) { var teamUser = new(TeamUser) - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) err := testEngine.Sync2(new(Team), teamUser) assert.NoError(t, err) @@ -194,37 +170,45 @@ func TestFind3(t *testing.T) { } func TestFindMap(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) + cnt, err := testEngine.Insert(&Userinfo{ + Username: "lunny", + Departname: "depart1", + IsMan: true, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + users := make(map[int64]Userinfo) - err := testEngine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - for _, user := range users { - fmt.Println(user) - } + err = testEngine.Find(&users) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(users)) + assert.EqualValues(t, "lunny", users[1].Username) + assert.EqualValues(t, "depart1", users[1].Departname) + assert.True(t, users[1].IsMan) + + users = make(map[int64]Userinfo) + err = testEngine.Cols("username, departname").Find(&users) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(users)) + assert.EqualValues(t, "lunny", users[1].Username) + assert.EqualValues(t, "depart1", users[1].Departname) + assert.False(t, users[1].IsMan) } func TestFindMap2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) users := make(map[int64]*Userinfo) err := testEngine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - for id, user := range users { - fmt.Println(id, user) - } + assert.NoError(t, err) } func TestDistinct(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) _, err := testEngine.Insert(&Userinfo{ @@ -238,8 +222,6 @@ func TestDistinct(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(users)) - fmt.Println(users) - type Depart struct { Departname string } @@ -247,31 +229,24 @@ func TestDistinct(t *testing.T) { users2 := make([]Depart, 0) err = testEngine.Distinct(departname).Table(new(Userinfo)).Find(&users2) assert.NoError(t, err) - if len(users2) != 1 { - fmt.Println(len(users2)) - t.Error(err) - panic(errors.New("should be one record")) - } - fmt.Println(users2) + assert.EqualValues(t, 1, len(users2)) } func TestOrder(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) err := testEngine.OrderBy("id desc").Find(&users) assert.NoError(t, err) - fmt.Println(users) users2 := make([]Userinfo, 0) err = testEngine.Asc("id", "username").Desc("height").Find(&users2) assert.NoError(t, err) - fmt.Println(users2) } func TestGroupBy(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) @@ -280,207 +255,151 @@ func TestGroupBy(t *testing.T) { } func TestHaving(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) err := testEngine.GroupBy("username").Having("username='xlw'").Find(&users) assert.NoError(t, err) - fmt.Println(users) - - /*users = make([]Userinfo, 0) - err = testEngine.Cols("id, username").GroupBy("username").Having("username='xlw'").Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users)*/ } func 
TestOrderSameMapper(t *testing.T) { - assert.NoError(t, prepareEngine()) - testEngine.UnMapType(rValue(new(Userinfo)).Type()) + assert.NoError(t, PrepareEngine()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) mapper := testEngine.GetTableMapper() - testEngine.SetMapper(core.SameMapper{}) + testEngine.SetMapper(names.SameMapper{}) defer func() { - testEngine.UnMapType(rValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) testEngine.SetMapper(mapper) }() assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) - err := testEngine.OrderBy("(id) desc").Find(&users) + err := testEngine.OrderBy("id desc").Find(&users) assert.NoError(t, err) - fmt.Println(users) users2 := make([]Userinfo, 0) - err = testEngine.Asc("(id)", "Username").Desc("Height").Find(&users2) + err = testEngine.Asc("id", "Username").Desc("Height").Find(&users2) assert.NoError(t, err) - fmt.Println(users2) } func TestHavingSameMapper(t *testing.T) { - assert.NoError(t, prepareEngine()) - testEngine.UnMapType(rValue(new(Userinfo)).Type()) + assert.NoError(t, PrepareEngine()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) mapper := testEngine.GetTableMapper() - testEngine.SetMapper(core.SameMapper{}) + testEngine.SetMapper(names.SameMapper{}) defer func() { - testEngine.UnMapType(rValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) testEngine.SetMapper(mapper) }() assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) err := testEngine.GroupBy("`Username`").Having("`Username`='xlw'").Find(&users) - if err != nil { - t.Fatal(err) - } - fmt.Println(users) + assert.NoError(t, err) } func TestFindInts(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") var idsInt64 []int64 err := testEngine.Table(userinfo).Cols("id").Desc("id").Find(&idsInt64) - if err != nil { - t.Fatal(err) - } - fmt.Println(idsInt64) + assert.NoError(t, err) var idsInt32 []int32 err = testEngine.Table(userinfo).Cols("id").Desc("id").Find(&idsInt32) - if err != nil { - t.Fatal(err) - } - fmt.Println(idsInt32) + assert.NoError(t, err) var idsInt []int err = testEngine.Table(userinfo).Cols("id").Desc("id").Find(&idsInt) - if err != nil { - t.Fatal(err) - } - fmt.Println(idsInt) + assert.NoError(t, err) var idsUint []uint err = testEngine.Table(userinfo).Cols("id").Desc("id").Find(&idsUint) - if err != nil { - t.Fatal(err) - } - fmt.Println(idsUint) + assert.NoError(t, err) type MyInt int var idsMyInt []MyInt err = testEngine.Table(userinfo).Cols("id").Desc("id").Find(&idsMyInt) - if err != nil { - t.Fatal(err) - } - fmt.Println(idsMyInt) + assert.NoError(t, err) } func TestFindStrings(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") username := testEngine.GetColumnMapper().Obj2Table("Username") var idsString []string err := testEngine.Table(userinfo).Cols(username).Desc("id").Find(&idsString) - if err != nil { - t.Fatal(err) - } - fmt.Println(idsString) + assert.NoError(t, err) } func TestFindMyString(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") username := testEngine.GetColumnMapper().Obj2Table("Username") var idsMyString []MyString err := 
testEngine.Table(userinfo).Cols(username).Desc("id").Find(&idsMyString) - if err != nil { - t.Fatal(err) - } - fmt.Println(idsMyString) + assert.NoError(t, err) } func TestFindInterface(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") username := testEngine.GetColumnMapper().Obj2Table("Username") var idsInterface []interface{} err := testEngine.Table(userinfo).Cols(username).Desc("id").Find(&idsInterface) - if err != nil { - t.Fatal(err) - } - fmt.Println(idsInterface) + assert.NoError(t, err) } func TestFindSliceBytes(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") var ids [][][]byte err := testEngine.Table(userinfo).Desc("id").Find(&ids) - if err != nil { - t.Fatal(err) - } - for _, record := range ids { - fmt.Println(record) - } + assert.NoError(t, err) } func TestFindSlicePtrString(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") var ids [][]*string err := testEngine.Table(userinfo).Desc("id").Find(&ids) - if err != nil { - t.Fatal(err) - } - for _, record := range ids { - fmt.Println(record) - } + assert.NoError(t, err) } func TestFindMapBytes(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") var ids []map[string][]byte err := testEngine.Table(userinfo).Desc("id").Find(&ids) - if err != nil { - t.Fatal(err) - } - for _, record := range ids { - fmt.Println(record) - } + assert.NoError(t, err) } func TestFindMapPtrString(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") var ids []map[string]*string err := testEngine.Table(userinfo).Desc("id").Find(&ids) assert.NoError(t, err) - for _, record := range ids { - fmt.Println(record) - } } func TestFindBit(t *testing.T) { @@ -489,7 +408,7 @@ func TestFindBit(t *testing.T) { Msg bool `xorm:"bit"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(FindBitStruct)) cnt, err := testEngine.Insert([]FindBitStruct{ @@ -517,7 +436,7 @@ func TestFindMark(t *testing.T) { MarkA string `xorm:"VARCHAR(1)"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Mark)) cnt, err := testEngine.Insert([]Mark{ @@ -548,7 +467,7 @@ func TestFindAndCountOneFunc(t *testing.T) { Msg bool `xorm:"bit"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(FindAndCountStruct)) cnt, err := testEngine.Insert([]FindAndCountStruct{ @@ -565,6 +484,12 @@ func TestFindAndCountOneFunc(t *testing.T) { assert.EqualValues(t, 2, cnt) var results = make([]FindAndCountStruct, 0, 2) + cnt, err = testEngine.Limit(1).FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + assert.EqualValues(t, 2, cnt) + + results = make([]FindAndCountStruct, 0, 2) cnt, err = testEngine.FindAndCount(&results) assert.NoError(t, err) assert.EqualValues(t, 2, len(results)) @@ -591,10 +516,96 @@ func TestFindAndCountOneFunc(t *testing.T) { results = make([]FindAndCountStruct, 0, 1) cnt, err = 
testEngine.Where("msg = ?", true).Desc("id"). - Limit(1).FindAndCount(&results) + Limit(1).Cols("content").FindAndCount(&results) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) assert.EqualValues(t, 1, cnt) + + ids := make([]int64, 0, 2) + tableName := testEngine.GetTableMapper().Obj2Table("FindAndCountStruct") + cnt, err = testEngine.Table(tableName).Limit(1).Cols("id").FindAndCount(&ids) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(ids)) + assert.EqualValues(t, 2, cnt) +} + +func TestFindAndCountOneFuncWithDeleted(t *testing.T) { + type CommentWithDeleted struct { + Id int `xorm:"pk autoincr"` + DeletedAt int64 `xorm:"deleted notnull default(0) index"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(CommentWithDeleted)) + + var comments []CommentWithDeleted + cnt, err := testEngine.FindAndCount(&comments) + assert.NoError(t, err) + assert.EqualValues(t, 0, cnt) +} + +func TestFindAndCount2(t *testing.T) { + // User + type TestFindAndCountUser struct { + Id int64 `xorm:"bigint(11) pk autoincr"` + Name string `xorm:"'name'"` + } + + // Hotel + type TestFindAndCountHotel struct { + Id int64 `xorm:"bigint(11) pk autoincr"` + Name string `xorm:"'name'"` + Code string `xorm:"'code'"` + Region string `xorm:"'region'"` + CreateBy *TestFindAndCountUser `xorm:"'create_by'"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(TestFindAndCountUser), new(TestFindAndCountHotel)) + + var u = TestFindAndCountUser{ + Name: "myname", + } + cnt, err := testEngine.Insert(&u) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var hotel = TestFindAndCountHotel{ + Name: "myhotel", + Code: "111", + Region: "222", + CreateBy: &u, + } + cnt, err = testEngine.Insert(&hotel) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + hotels := make([]*TestFindAndCountHotel, 0) + cnt, err = testEngine. + Alias("t"). + Limit(10, 0). + FindAndCount(&hotels) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + hotels = make([]*TestFindAndCountHotel, 0) + cnt, err = testEngine. + Table(new(TestFindAndCountHotel)). + Alias("t"). + Limit(10, 0). + FindAndCount(&hotels) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + hotels = make([]*TestFindAndCountHotel, 0) + cnt, err = testEngine. + Table(new(TestFindAndCountHotel)). + Alias("t"). + Where("t.region like '6501%'"). + Limit(10, 0). 
+ FindAndCount(&hotels) + assert.NoError(t, err) + assert.EqualValues(t, 0, cnt) } type FindMapDevice struct { @@ -607,7 +618,7 @@ func (device *FindMapDevice) TableName() string { } func TestFindMapStringId(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(FindMapDevice)) cnt, err := testEngine.Insert(&FindMapDevice{ @@ -678,7 +689,7 @@ func TestFindExtends(t *testing.T) { FindExtendsB `xorm:"extends"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(FindExtendsA)) cnt, err := testEngine.Insert(&FindExtendsA{ @@ -713,7 +724,7 @@ func TestFindExtends3(t *testing.T) { FindExtendsBB `xorm:"extends"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(FindExtendsAA)) cnt, err := testEngine.Insert(&FindExtendsAA{ @@ -749,7 +760,7 @@ func TestFindCacheLimit(t *testing.T) { Created time.Time `xorm:"created"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(InviteCode)) cnt, err := testEngine.Insert(&InviteCode{ @@ -790,8 +801,12 @@ func TestFindJoin(t *testing.T) { DeviceId int64 } - assert.NoError(t, prepareEngine()) - assertSync(t, new(SceneItem), new(DeviceUserPrivrels)) + type Order struct { + Id int64 + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(SceneItem), new(DeviceUserPrivrels), new(Order)) var scenes []SceneItem err := testEngine.Join("LEFT OUTER", "device_user_privrels", "device_user_privrels.device_id=scene_item.device_id"). @@ -802,6 +817,10 @@ func TestFindJoin(t *testing.T) { err = testEngine.Join("LEFT OUTER", new(DeviceUserPrivrels), "device_user_privrels.device_id=scene_item.device_id"). Where("scene_item.type=?", 3).Or("device_user_privrels.user_id=?", 339).Find(&scenes) assert.NoError(t, err) + + scenes = make([]SceneItem, 0) + err = testEngine.Join("INNER", "order", "`scene_item`.device_id=`order`.id").Find(&scenes) + assert.NoError(t, err) } func TestJoinFindLimit(t *testing.T) { @@ -816,7 +835,7 @@ func TestJoinFindLimit(t *testing.T) { Name string } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(JoinFindLimit1), new(JoinFindLimit2)) var finds []JoinFindLimit1 @@ -824,3 +843,70 @@ func TestJoinFindLimit(t *testing.T) { Limit(10, 10).Find(&finds) assert.NoError(t, err) } + +func TestMoreExtends(t *testing.T) { + type MoreExtendsUsers struct { + ID int64 `xorm:"id autoincr pk" json:"id"` + Name string `xorm:"name not null" json:"name"` + CreatedAt time.Time `xorm:"created not null" json:"created_at"` + UpdatedAt time.Time `xorm:"updated not null" json:"updated_at"` + DeletedAt time.Time `xorm:"deleted" json:"deleted_at"` + } + + type MoreExtendsBooks struct { + ID int64 `xorm:"id autoincr pk" json:"id"` + Name string `xorm:"name not null" json:"name"` + UserID int64 `xorm:"user_id not null" json:"user_id"` + CreatedAt time.Time `xorm:"created not null" json:"created_at"` + UpdatedAt time.Time `xorm:"updated not null" json:"updated_at"` + DeletedAt time.Time `xorm:"deleted" json:"deleted_at"` + } + + type MoreExtendsBooksExtend struct { + MoreExtendsBooks `xorm:"extends"` + Users MoreExtendsUsers `xorm:"extends" json:"users"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(MoreExtendsUsers), new(MoreExtendsBooks)) + + var books []MoreExtendsBooksExtend + err := testEngine.Table("more_extends_books").Select("more_extends_books.*, more_extends_users.*"). 
+ Join("INNER", "more_extends_users", "more_extends_books.user_id = more_extends_users.id"). + Where("more_extends_books.name LIKE ?", "abc"). + Limit(10, 10). + Find(&books) + assert.NoError(t, err) + + books = make([]MoreExtendsBooksExtend, 0, len(books)) + err = testEngine.Table("more_extends_books"). + Alias("m"). + Select("m.*, more_extends_users.*"). + Join("INNER", "more_extends_users", "m.user_id = more_extends_users.id"). + Where("m.name LIKE ?", "abc"). + Limit(10, 10). + Find(&books) + assert.NoError(t, err) +} + +func TestDistinctAndCols(t *testing.T) { + type DistinctAndCols struct { + Id int64 + Name string + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(DistinctAndCols)) + + cnt, err := testEngine.Insert(&DistinctAndCols{ + Name: "test", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var names []string + err = testEngine.Table("distinct_and_cols").Cols("name").Distinct("name").Find(&names) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(names)) + assert.EqualValues(t, "test", names[0]) +} diff --git a/session_get_test.go b/integrations/session_get_test.go similarity index 88% rename from session_get_test.go rename to integrations/session_get_test.go index cd5088e..35e66fb 100644 --- a/session_get_test.go +++ b/integrations/session_get_test.go @@ -2,20 +2,73 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "database/sql" "fmt" + "strconv" "testing" "time" + "github.com/xormplus/xorm/contexts" + "github.com/xormplus/xorm/schemas" + "github.com/stretchr/testify/assert" - "github.com/xormplus/core" ) +func convertInt(v interface{}) (int64, error) { + switch v.(type) { + case int: + return int64(v.(int)), nil + case int8: + return int64(v.(int8)), nil + case int16: + return int64(v.(int16)), nil + case int32: + return int64(v.(int32)), nil + case int64: + return v.(int64), nil + case []byte: + i, err := strconv.ParseInt(string(v.([]byte)), 10, 64) + if err != nil { + return 0, err + } + return i, nil + case string: + i, err := strconv.ParseInt(v.(string), 10, 64) + if err != nil { + return 0, err + } + return i, nil + } + return 0, fmt.Errorf("unsupported type: %v", v) +} + +func convertFloat(v interface{}) (float64, error) { + switch v.(type) { + case float32: + return float64(v.(float32)), nil + case float64: + return v.(float64), nil + case string: + i, err := strconv.ParseFloat(v.(string), 64) + if err != nil { + return 0, err + } + return i, nil + case []byte: + i, err := strconv.ParseFloat(string(v.([]byte)), 64) + if err != nil { + return 0, err + } + return i, nil + } + return 0, fmt.Errorf("unsupported type: %v", v) +} + func TestGetVar(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type GetVar struct { Id int64 `xorm:"autoincr pk"` @@ -153,7 +206,7 @@ func TestGetVar(t *testing.T) { assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money)) var money2 float64 - if testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { has, err = testEngine.SQL("SELECT TOP 1 money FROM " + testEngine.TableName("get_var", true)).Get(&money2) } else { has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2) @@ -178,7 +231,7 @@ func TestGetVar(t *testing.T) { assert.Equal(t, "1.5", valuesString["money"]) // for mymysql driver, interface{} will be []byte, so ignore it currently - if 
testEngine.Dialect().DriverName() != "mymysql" { + if testEngine.DriverName() != "mymysql" { var valuesInter = make(map[string]interface{}) has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter) assert.NoError(t, err) @@ -220,7 +273,7 @@ func TestGetVar(t *testing.T) { } func TestGetStruct(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserinfoGet struct { Uid int `xorm:"pk autoincr"` @@ -233,7 +286,7 @@ func TestGetStruct(t *testing.T) { defer session.Close() var err error - if testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Begin() assert.NoError(t, err) _, err = session.Exec("SET IDENTITY_INSERT userinfo_get ON") @@ -242,7 +295,7 @@ func TestGetStruct(t *testing.T) { cnt, err := session.Insert(&UserinfoGet{Uid: 2}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - if testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Commit() assert.NoError(t, err) } @@ -275,7 +328,7 @@ func TestGetStruct(t *testing.T) { } func TestGetSlice(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserinfoSlice struct { Uid int `xorm:"pk autoincr"` @@ -291,7 +344,7 @@ func TestGetSlice(t *testing.T) { } func TestGetError(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type GetError struct { Uid int `xorm:"pk autoincr"` @@ -311,7 +364,7 @@ func TestGetError(t *testing.T) { } func TestJSONString(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type JsonString struct { Id int64 @@ -334,17 +387,17 @@ func TestJSONString(t *testing.T) { assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 1, js.Id) - assert.EqualValues(t, `["1","2"]`, js.Content) + assert.True(t, `["1","2"]` == js.Content || `["1", "2"]` == js.Content) var jss []JsonString err = testEngine.Table("json_json").Find(&jss) assert.NoError(t, err) assert.EqualValues(t, 1, len(jss)) - assert.EqualValues(t, `["1","2"]`, jss[0].Content) + assert.True(t, `["1","2"]` == jss[0].Content || `["1", "2"]` == jss[0].Content) } func TestGetActionMapping(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type ActionMapping struct { ActionId string `xorm:"pk"` @@ -381,7 +434,7 @@ func TestGetStructId(t *testing.T) { Id int64 } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(TestGetStruct)) _, err := testEngine.Insert(&TestGetStruct{}) @@ -408,7 +461,7 @@ func TestContextGet(t *testing.T) { Name string } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(ContextGetStruct)) _, err := testEngine.Insert(&ContextGetStruct{Name: "1"}) @@ -417,7 +470,7 @@ func TestContextGet(t *testing.T) { sess := testEngine.NewSession() defer sess.Close() - context := NewMemoryContextCache() + context := contexts.NewMemoryContextCache() var c2 ContextGetStruct has, err := sess.ID(1).NoCache().ContextCache(context).Get(&c2) @@ -446,13 +499,13 @@ func TestContextGet2(t *testing.T) { Name string } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(ContextGetStruct2)) _, err := testEngine.Insert(&ContextGetStruct2{Name: "1"}) assert.NoError(t, err) - context := NewMemoryContextCache() + context := contexts.NewMemoryContextCache() var c2 ContextGetStruct2 has, err := 
testEngine.ID(1).NoCache().ContextCache(context).Get(&c2) @@ -485,7 +538,7 @@ func (MyGetCustomTableImpletation) TableName() string { } func TestGetCustomTableInterface(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Table(getCustomTableName).Sync2(new(MyGetCustomTableImpletation))) exist, err := testEngine.IsTableExist(getCustomTableName) @@ -510,7 +563,7 @@ func TestGetNullVar(t *testing.T) { Age int } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(TestGetNullVarStruct)) affected, err := testEngine.Exec("insert into " + testEngine.TableName(new(TestGetNullVarStruct), true) + " (name,age) values (null,null)") @@ -595,7 +648,7 @@ func TestCustomTypes(t *testing.T) { Age MyInt } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(TestCustomizeStruct)) var s = TestCustomizeStruct{ @@ -626,7 +679,7 @@ func TestGetViaMapCond(t *testing.T) { Index int } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(GetViaMapCond)) var ( diff --git a/session_insert_test.go b/integrations/session_insert_test.go similarity index 72% rename from session_insert_test.go rename to integrations/session_insert_test.go index 7b1aca0..9deb04a 100644 --- a/session_insert_test.go +++ b/integrations/session_insert_test.go @@ -2,20 +2,21 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( - "errors" "fmt" "reflect" "testing" "time" + "github.com/xormplus/xorm" + "github.com/stretchr/testify/assert" ) func TestInsertOne(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type Test struct { Id int64 `xorm:"autoincr pk"` @@ -30,25 +31,9 @@ func TestInsertOne(t *testing.T) { assert.NoError(t, err) } -func TestInsertOne2(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type Test2 struct { - Id int64 `xorm:"autoincr pk"` - Msg string `xorm:"varchar(255)"` - Created time.Time `xorm:"datetime created"` - } - - assert.NoError(t, testEngine.Sync2(new(Test2))) - - data := Test2{Msg: "hi"} - _, err := testEngine.InsertOne(data) - assert.NoError(t, err) -} - func TestInsertMulti(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type TestMulti struct { Id int64 `xorm:"int(11) pk"` Name string `xorm:"varchar(255)"` @@ -123,7 +108,7 @@ func callbackLooper(datas interface{}, step int, actionFunc func(interface{}) er } func TestInsertOneIfPkIsPoint(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type TestPoint struct { Id *int64 `xorm:"autoincr pk notnull 'id'"` @@ -139,7 +124,7 @@ func TestInsertOneIfPkIsPoint(t *testing.T) { } func TestInsertOneIfPkIsPointRename(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type ID *int64 type TestPoint2 struct { Id ID `xorm:"autoincr pk notnull 'id'"` @@ -155,7 +140,7 @@ func TestInsertOneIfPkIsPointRename(t *testing.T) { } func TestInsert(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) user := Userinfo{0, "xiaolunwen", "dev", "lunny", time.Now(), @@ -170,32 +155,19 @@ func TestInsert(t *testing.T) { // Username is unique, so this should return error assert.Error(t, err, "insert should fail but no error returned") assert.EqualValues(t, 0, cnt, "insert not returned 1") - if 
err == nil { - panic("should return err") - } } func TestInsertAutoIncr(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) // auto increment insert user := Userinfo{Username: "xiaolunwen2", Departname: "dev", Alias: "lunny", Created: time.Now(), Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true} cnt, err := testEngine.Insert(&user) - fmt.Println(user.Uid) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - } - if user.Uid <= 0 { - t.Error(errors.New("not return id error")) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + assert.Greater(t, user.Uid, int64(0)) } type DefaultInsert struct { @@ -207,7 +179,7 @@ type DefaultInsert struct { } func TestInsertDefault(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) di := new(DefaultInsert) err := testEngine.Sync2(di) @@ -217,28 +189,12 @@ func TestInsertDefault(t *testing.T) { _, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("Status")).Insert(&di2) assert.NoError(t, err) - has, err := testEngine.Desc("(id)").Get(di) + has, err := testEngine.Desc("id").Get(di) assert.NoError(t, err) - if !has { - err = errors.New("error with no data") - t.Error(err) - panic(err) - } - if di.Status != -1 { - err = errors.New("inserted error data") - t.Error(err) - panic(err) - } - if di2.Updated.Unix() != di.Updated.Unix() { - err = errors.New("updated should equal") - t.Error(err, di.Updated, di2.Updated) - panic(err) - } - if di2.Created.Unix() != di.Created.Unix() { - err = errors.New("created should equal") - t.Error(err, di.Created, di2.Created) - panic(err) - } + assert.True(t, has) + assert.EqualValues(t, -1, di.Status) + assert.EqualValues(t, di2.Updated.Unix(), di.Updated.Unix()) + assert.EqualValues(t, di2.Created.Unix(), di.Created.Unix()) } type DefaultInsert2 struct { @@ -249,57 +205,24 @@ type DefaultInsert2 struct { } func TestInsertDefault2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) di := new(DefaultInsert2) err := testEngine.Sync2(di) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) var di2 = DefaultInsert2{Name: "test"} _, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("CheckTime")).Insert(&di2) - if err != nil { - t.Error(err) - } - - has, err := testEngine.Desc("(id)").Get(di) - if err != nil { - t.Error(err) - } - if !has { - err = errors.New("error with no data") - t.Error(err) - panic(err) - } - - has, err = testEngine.NoAutoCondition().Desc("(id)").Get(&di2) - if err != nil { - t.Error(err) - } - - if !has { - err = errors.New("error with no data") - t.Error(err) - panic(err) - } + assert.NoError(t, err) - if *di != di2 { - err = fmt.Errorf("%v is not equal to %v", di, di2) - t.Error(err) - panic(err) - } + has, err := testEngine.Desc("id").Get(di) + assert.NoError(t, err) + assert.True(t, has) - /*if di2.Updated.Unix() != di.Updated.Unix() { - err = errors.New("updated should equal") - t.Error(err, di.Updated, di2.Updated) - panic(err) - } - if di2.Created.Unix() != di.Created.Unix() { - err = errors.New("created should equal") - t.Error(err, di.Created, di2.Created) - panic(err) - }*/ + has, err = testEngine.NoAutoCondition().Desc("id").Get(&di2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, *di, di2) } type CreatedInsert struct { @@ -333,147 +256,91 @@ type CreatedInsert6 struct { } func 
TestInsertCreated(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) di := new(CreatedInsert) err := testEngine.Sync2(di) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + ci := &CreatedInsert{} _, err = testEngine.Insert(ci) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) - has, err := testEngine.Desc("(id)").Get(di) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci.Created.Unix() != di.Created.Unix() { - t.Fatal("should equal:", ci, di) - } - fmt.Println("ci:", ci, "di:", di) + has, err := testEngine.Desc("id").Get(di) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci.Created.Unix(), di.Created.Unix()) di2 := new(CreatedInsert2) err = testEngine.Sync2(di2) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + ci2 := &CreatedInsert2{} _, err = testEngine.Insert(ci2) - if err != nil { - t.Fatal(err) - } - has, err = testEngine.Desc("(id)").Get(di2) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci2.Created != di2.Created { - t.Fatal("should equal:", ci2, di2) - } - fmt.Println("ci2:", ci2, "di2:", di2) + assert.NoError(t, err) + + has, err = testEngine.Desc("id").Get(di2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci2.Created, di2.Created) di3 := new(CreatedInsert3) err = testEngine.Sync2(di3) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + ci3 := &CreatedInsert3{} _, err = testEngine.Insert(ci3) - if err != nil { - t.Fatal(err) - } - has, err = testEngine.Desc("(id)").Get(di3) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci3.Created != di3.Created { - t.Fatal("should equal:", ci3, di3) - } - fmt.Println("ci3:", ci3, "di3:", di3) + assert.NoError(t, err) + + has, err = testEngine.Desc("id").Get(di3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci3.Created, di3.Created) di4 := new(CreatedInsert4) err = testEngine.Sync2(di4) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + ci4 := &CreatedInsert4{} _, err = testEngine.Insert(ci4) - if err != nil { - t.Fatal(err) - } - has, err = testEngine.Desc("(id)").Get(di4) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci4.Created != di4.Created { - t.Fatal("should equal:", ci4, di4) - } - fmt.Println("ci4:", ci4, "di4:", di4) + assert.NoError(t, err) + + has, err = testEngine.Desc("id").Get(di4) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci4.Created, di4.Created) di5 := new(CreatedInsert5) err = testEngine.Sync2(di5) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + ci5 := &CreatedInsert5{} _, err = testEngine.Insert(ci5) - if err != nil { - t.Fatal(err) - } - has, err = testEngine.Desc("(id)").Get(di5) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci5.Created.Unix() != di5.Created.Unix() { - t.Fatal("should equal:", ci5, di5) - } - fmt.Println("ci5:", ci5, "di5:", di5) + assert.NoError(t, err) + + has, err = testEngine.Desc("id").Get(di5) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci5.Created.Unix(), di5.Created.Unix()) di6 := new(CreatedInsert6) err = testEngine.Sync2(di6) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + oldTime := time.Now().Add(-time.Hour) ci6 := &CreatedInsert6{Created: oldTime} _, err = testEngine.Insert(ci6) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) - has, err = 
testEngine.Desc("(id)").Get(di6) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci6.Created.Unix() != di6.Created.Unix() { - t.Fatal("should equal:", ci6, di6) - } - fmt.Println("ci6:", ci6, "di6:", di6) + has, err = testEngine.Desc("id").Get(di6) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci6.Created.Unix(), di6.Created.Unix()) } -type JsonTime time.Time +type JSONTime time.Time -func (j JsonTime) format() string { +func (j JSONTime) format() string { t := time.Time(j) if t.IsZero() { return "" @@ -482,11 +349,11 @@ func (j JsonTime) format() string { return t.Format("2006-01-02") } -func (j JsonTime) MarshalText() ([]byte, error) { +func (j JSONTime) MarshalText() ([]byte, error) { return []byte(j.format()), nil } -func (j JsonTime) MarshalJSON() ([]byte, error) { +func (j JSONTime) MarshalJSON() ([]byte, error) { return []byte(`"` + j.format() + `"`), nil } @@ -494,66 +361,55 @@ func TestDefaultTime3(t *testing.T) { type PrepareTask struct { Id int `xorm:"not null pk autoincr INT(11)" json:"id"` // ... - StartTime JsonTime `xorm:"not null default '2006-01-02 15:04:05' TIMESTAMP index" json:"start_time"` - EndTime JsonTime `xorm:"not null default '2006-01-02 15:04:05' TIMESTAMP" json:"end_time"` + StartTime JSONTime `xorm:"not null default '2006-01-02 15:04:05' TIMESTAMP index" json:"start_time"` + EndTime JSONTime `xorm:"not null default '2006-01-02 15:04:05' TIMESTAMP" json:"end_time"` Cuser string `xorm:"not null default '' VARCHAR(64) index" json:"cuser"` Muser string `xorm:"not null default '' VARCHAR(64)" json:"muser"` - Ctime JsonTime `xorm:"not null default CURRENT_TIMESTAMP TIMESTAMP created" json:"ctime"` - Mtime JsonTime `xorm:"not null default CURRENT_TIMESTAMP TIMESTAMP updated" json:"mtime"` + Ctime JSONTime `xorm:"not null default CURRENT_TIMESTAMP TIMESTAMP created" json:"ctime"` + Mtime JSONTime `xorm:"not null default CURRENT_TIMESTAMP TIMESTAMP updated" json:"mtime"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(PrepareTask)) prepareTask := &PrepareTask{ - StartTime: JsonTime(time.Now()), + StartTime: JSONTime(time.Now()), Cuser: "userId", Muser: "userId", } - cnt, err := testEngine.Omit("end_time").InsertOne(prepareTask) + cnt, err := testEngine.Omit("end_time").Insert(prepareTask) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) } -type MyJsonTime struct { +type MyJSONTime struct { Id int64 `json:"id"` - Created JsonTime `xorm:"created" json:"created_at"` + Created JSONTime `xorm:"created" json:"created_at"` } func TestCreatedJsonTime(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) - di5 := new(MyJsonTime) + di5 := new(MyJSONTime) err := testEngine.Sync2(di5) - if err != nil { - t.Fatal(err) - } - ci5 := &MyJsonTime{} + assert.NoError(t, err) + + ci5 := &MyJSONTime{} _, err = testEngine.Insert(ci5) - if err != nil { - t.Fatal(err) - } - has, err := testEngine.Desc("(id)").Get(di5) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if time.Time(ci5.Created).Unix() != time.Time(di5.Created).Unix() { - t.Fatal("should equal:", time.Time(ci5.Created).Unix(), time.Time(di5.Created).Unix()) - } - fmt.Println("ci5:", ci5, "di5:", di5) + assert.NoError(t, err) - var dis = make([]MyJsonTime, 0) + has, err := testEngine.Desc("id").Get(di5) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, time.Time(ci5.Created).Unix(), time.Time(di5.Created).Unix()) + + var dis = 
make([]MyJSONTime, 0) err = testEngine.Find(&dis) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) } func TestInsertMulti2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) @@ -564,18 +420,41 @@ func TestInsertMulti2(t *testing.T) { {Username: "xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, } cnt, err := testEngine.Insert(&users) - if err != nil { - t.Error(err) - panic(err) + assert.NoError(t, err) + assert.EqualValues(t, len(users), cnt) + + users2 := []*Userinfo{ + {Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + {Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + } + + cnt, err = testEngine.Insert(&users2) + assert.NoError(t, err) + assert.EqualValues(t, len(users2), cnt) +} + +func TestInsertMulti2Interface(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + assertSync(t, new(Userinfo)) + + users := []interface{}{ + Userinfo{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + Userinfo{Username: "xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + Userinfo{Username: "xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + Userinfo{Username: "xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, } - if cnt != int64(len(users)) { - err = errors.New("insert not returned 1") + + cnt, err := testEngine.Insert(&users) + if err != nil { t.Error(err) panic(err) - return } + assert.EqualValues(t, len(users), cnt) - users2 := []*Userinfo{ + users2 := []interface{}{ &Userinfo{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, &Userinfo{Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, &Userinfo{Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, @@ -583,20 +462,12 @@ func TestInsertMulti2(t *testing.T) { } cnt, err = testEngine.Insert(&users2) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != int64(len(users2)) { - err = errors.New(fmt.Sprintf("insert not returned %v", len(users2))) - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, len(users2), cnt) } func TestInsertTwoTable(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo), new(Userdetail)) @@ -604,32 +475,14 @@ func TestInsertTwoTable(t *testing.T) { userinfo := Userinfo{Username: "xlw3", Departname: "dev", Alias: "lunny4", Created: time.Now(), Detail: userdetail} cnt, err := testEngine.Insert(&userinfo, &userdetail) - if err != nil { - t.Error(err) - panic(err) - } - - if userinfo.Uid <= 0 { - err = errors.New("not return id error") - t.Error(err) - panic(err) - } - - if userdetail.Id <= 0 { - err = errors.New("not return id error") - t.Error(err) - panic(err) - } - - if cnt != 2 { - err = errors.New("insert not returned 2") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.Greater(t, userinfo.Uid, int64(0)) + assert.Greater(t, userdetail.Id, int64(0)) + assert.EqualValues(t, 2, cnt) } func TestInsertCreatedInt64(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type TestCreatedInt64 struct { Id int64 `xorm:"autoincr pk"` @@ -661,7 +514,7 @@ func (MyUserinfo) TableName() string { } func TestInsertMulti3(t *testing.T) { 
- assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) testEngine.ShowSQL(true) assertSync(t, new(MyUserinfo)) @@ -677,10 +530,10 @@ func TestInsertMulti3(t *testing.T) { assert.EqualValues(t, len(users), cnt) users2 := []*MyUserinfo{ - &MyUserinfo{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, - &MyUserinfo{Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, - &MyUserinfo{Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, - &MyUserinfo{Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + {Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + {Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, } cnt, err = testEngine.Insert(&users2) @@ -705,7 +558,7 @@ func (MyUserinfo2) TableName() string { } func TestInsertMulti4(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) testEngine.ShowSQL(false) assertSync(t, new(MyUserinfo2)) @@ -722,10 +575,10 @@ func TestInsertMulti4(t *testing.T) { assert.EqualValues(t, len(users), cnt) users2 := []*MyUserinfo2{ - &MyUserinfo2{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, - &MyUserinfo2{Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, - &MyUserinfo2{Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, - &MyUserinfo2{Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + {Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + {Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, } cnt, err = testEngine.Insert(&users2) @@ -751,7 +604,7 @@ func TestAnonymousStruct(t *testing.T) { } `json:"ext" xorm:"'EXT' json notnull"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(PlainFoo)) _, err := testEngine.Insert(&PlainFoo{ @@ -780,7 +633,7 @@ func TestInsertMap(t *testing.T) { Name string } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(InsertMap)) cnt, err := testEngine.Table(new(InsertMap)).Insert(map[string]interface{}{ @@ -865,7 +718,7 @@ func TestInsertWhere(t *testing.T) { IsTrue bool } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(InsertWhere)) var i = InsertWhere{ @@ -959,6 +812,64 @@ func TestInsertWhere(t *testing.T) { assert.EqualValues(t, 5, j5.Index) } +func TestInsertExpr2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type InsertExprsRelease struct { + Id int64 + RepoId int + IsTag bool + IsDraft bool + NumCommits int + Sha1 string + } + + assertSync(t, new(InsertExprsRelease)) + + var ie = InsertExprsRelease{ + RepoId: 1, + IsTag: true, + } + inserted, err := testEngine. + SetExpr("is_draft", true). + SetExpr("num_commits", 0). + SetExpr("sha1", ""). 
+ Insert(&ie) + assert.NoError(t, err) + assert.EqualValues(t, 1, inserted) + + var ie2 InsertExprsRelease + has, err := testEngine.ID(ie.Id).Get(&ie2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, true, ie2.IsDraft) + assert.EqualValues(t, "", ie2.Sha1) + assert.EqualValues(t, 0, ie2.NumCommits) + assert.EqualValues(t, 1, ie2.RepoId) + assert.EqualValues(t, true, ie2.IsTag) + + inserted, err = testEngine.Table(new(InsertExprsRelease)). + SetExpr("is_draft", true). + SetExpr("num_commits", 0). + SetExpr("sha1", ""). + Insert(map[string]interface{}{ + "repo_id": 1, + "is_tag": true, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, inserted) + + var ie3 InsertExprsRelease + has, err = testEngine.ID(ie.Id + 1).Get(&ie3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, true, ie3.IsDraft) + assert.EqualValues(t, "", ie3.Sha1) + assert.EqualValues(t, 0, ie3.NumCommits) + assert.EqualValues(t, 1, ie3.RepoId) + assert.EqualValues(t, true, ie3.IsTag) +} + type NightlyRate struct { ID int64 `xorm:"'id' not null pk BIGINT(20)" json:"id"` } @@ -968,7 +879,7 @@ func (NightlyRate) TableName() string { } func TestMultipleInsertTableName(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) tableName := `prd_nightly_rate_16` assert.NoError(t, testEngine.Table(tableName).Sync2(new(NightlyRate))) @@ -999,7 +910,7 @@ func TestMultipleInsertTableName(t *testing.T) { } func TestInsertMultiWithOmit(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type TestMultiOmit struct { Id int64 `xorm:"int(11) pk"` @@ -1042,7 +953,7 @@ func TestInsertMultiWithOmit(t *testing.T) { } func TestInsertTwice(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type InsertStructA struct { FieldA int @@ -1056,7 +967,7 @@ func TestInsertTwice(t *testing.T) { var sliceA []InsertStructA // sliceA is empty sliceB := []InsertStructB{ - InsertStructB{ + { FieldB: 1, }, } @@ -1068,7 +979,7 @@ func TestInsertTwice(t *testing.T) { assert.NoError(t, err) _, err = ssn.Insert(sliceA) - assert.EqualValues(t, ErrNoElementsOnSlice, err) + assert.EqualValues(t, xorm.ErrNoElementsOnSlice, err) _, err = ssn.Insert(sliceB) assert.NoError(t, err) diff --git a/session_iterate_test.go b/integrations/session_iterate_test.go similarity index 96% rename from session_iterate_test.go rename to integrations/session_iterate_test.go index bb0c59c..564f457 100644 --- a/session_iterate_test.go +++ b/integrations/session_iterate_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "testing" @@ -11,7 +11,7 @@ import ( ) func TestIterate(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserIterate struct { Id int64 @@ -39,7 +39,7 @@ func TestIterate(t *testing.T) { } func TestBufferIterate(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserBufferIterate struct { Id int64 diff --git a/integrations/session_pk_test.go b/integrations/session_pk_test.go new file mode 100644 index 0000000..8eff58d --- /dev/null +++ b/integrations/session_pk_test.go @@ -0,0 +1,673 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
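+//
+// The tests below all follow the testify assertion style adopted by this
+// patch. As a representative sketch (using IntId, one of the fixtures
+// defined in this file):
+//
+//	cnt, err := testEngine.Insert(&IntId{Name: "test"})
+//	assert.NoError(t, err)
+//	assert.EqualValues(t, 1, cnt)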
+ +package integrations + +import ( + "sort" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/xormplus/xorm/schemas" +) + +type IntId struct { + Id int `xorm:"pk autoincr"` + Name string +} + +type Int16Id struct { + Id int16 `xorm:"pk autoincr"` + Name string +} + +type Int32Id struct { + Id int32 `xorm:"pk autoincr"` + Name string +} + +type UintId struct { + Id uint `xorm:"pk autoincr"` + Name string +} + +type Uint16Id struct { + Id uint16 `xorm:"pk autoincr"` + Name string +} + +type Uint32Id struct { + Id uint32 `xorm:"pk autoincr"` + Name string +} + +type Uint64Id struct { + Id uint64 `xorm:"pk autoincr"` + Name string +} + +type StringPK struct { + Id string `xorm:"pk notnull"` + Name string +} + +type ID int64 +type MyIntPK struct { + ID ID `xorm:"pk autoincr"` + Name string +} + +type StrID string +type MyStringPK struct { + ID StrID `xorm:"pk notnull"` + Name string +} + +func TestIntId(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&IntId{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&IntId{}) + assert.NoError(t, err) + + cnt, err := testEngine.Insert(&IntId{Name: "test"}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + bean := new(IntId) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + + beans := make([]IntId, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + + beans2 := make(map[int]IntId) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans2)) + + cnt, err = testEngine.ID(bean.Id).Delete(&IntId{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestInt16Id(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&Int16Id{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Int16Id{}) + assert.NoError(t, err) + + cnt, err := testEngine.Insert(&Int16Id{Name: "test"}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + bean := new(Int16Id) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + + beans := make([]Int16Id, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + + beans2 := make(map[int16]Int16Id, 0) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans2)) + + cnt, err = testEngine.ID(bean.Id).Delete(&Int16Id{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestInt32Id(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&Int32Id{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Int32Id{}) + assert.NoError(t, err) + + cnt, err := testEngine.Insert(&Int32Id{Name: "test"}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + bean := new(Int32Id) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + + beans := make([]Int32Id, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + + beans2 := make(map[int32]Int32Id, 0) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans2)) + + cnt, err = testEngine.ID(bean.Id).Delete(&Int32Id{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestUintId(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&UintId{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&UintId{}) + 
assert.NoError(t, err) + + cnt, err := testEngine.Insert(&UintId{Name: "test"}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var inserts = []UintId{ + {Name: "test1"}, + {Name: "test2"}, + } + cnt, err = testEngine.Insert(&inserts) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + + bean := new(UintId) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + + beans := make([]UintId, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 3, len(beans)) + + beans2 := make(map[uint]UintId, 0) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 3, len(beans2)) + + cnt, err = testEngine.ID(bean.Id).Delete(&UintId{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestUint16Id(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&Uint16Id{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Uint16Id{}) + assert.NoError(t, err) + + cnt, err := testEngine.Insert(&Uint16Id{Name: "test"}) + assert.NoError(t, err) + + assert.EqualValues(t, 1, cnt) + + bean := new(Uint16Id) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + + beans := make([]Uint16Id, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + + beans2 := make(map[uint16]Uint16Id, 0) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans2)) + + cnt, err = testEngine.ID(bean.Id).Delete(&Uint16Id{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestUint32Id(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&Uint32Id{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Uint32Id{}) + assert.NoError(t, err) + + cnt, err := testEngine.Insert(&Uint32Id{Name: "test"}) + assert.NoError(t, err) + + assert.EqualValues(t, 1, cnt) + + bean := new(Uint32Id) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + + beans := make([]Uint32Id, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + + beans2 := make(map[uint32]Uint32Id, 0) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans2)) + + cnt, err = testEngine.ID(bean.Id).Delete(&Uint32Id{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestUint64Id(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&Uint64Id{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Uint64Id{}) + assert.NoError(t, err) + + idbean := &Uint64Id{Name: "test"} + cnt, err := testEngine.Insert(idbean) + assert.NoError(t, err) + + assert.EqualValues(t, 1, cnt) + + bean := new(Uint64Id) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, bean.Id, idbean.Id) + + beans := make([]Uint64Id, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + assert.EqualValues(t, *bean, beans[0]) + + beans2 := make(map[uint64]Uint64Id, 0) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans2)) + assert.EqualValues(t, *bean, beans2[bean.Id]) + + cnt, err = testEngine.ID(bean.Id).Delete(&Uint64Id{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestStringPK(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := 
testEngine.DropTables(&StringPK{})
+ assert.NoError(t, err)
+
+ err = testEngine.CreateTables(&StringPK{})
+ assert.NoError(t, err)
+
+ cnt, err := testEngine.Insert(&StringPK{Id: "1-1-2", Name: "test"})
+ assert.NoError(t, err)
+
+ assert.EqualValues(t, 1, cnt)
+
+ bean := new(StringPK)
+ has, err := testEngine.Get(bean)
+ assert.NoError(t, err)
+ assert.True(t, has)
+
+ beans := make([]StringPK, 0)
+ err = testEngine.Find(&beans)
+ assert.NoError(t, err)
+ assert.EqualValues(t, 1, len(beans))
+
+ beans2 := make(map[string]StringPK)
+ err = testEngine.Find(&beans2)
+ assert.NoError(t, err)
+ assert.EqualValues(t, 1, len(beans2))
+
+ cnt, err = testEngine.ID(bean.Id).Delete(&StringPK{})
+ assert.NoError(t, err)
+ assert.EqualValues(t, 1, cnt)
+}
+
+type CompositeKey struct {
+ Id1 int64 `xorm:"id1 pk"`
+ Id2 int64 `xorm:"id2 pk"`
+ UpdateStr string
+}
+
+func TestCompositeKey(t *testing.T) {
+ assert.NoError(t, PrepareEngine())
+
+ err := testEngine.DropTables(&CompositeKey{})
+ assert.NoError(t, err)
+
+ err = testEngine.CreateTables(&CompositeKey{})
+ assert.NoError(t, err)
+
+ cnt, err := testEngine.Insert(&CompositeKey{11, 22, ""})
+ assert.NoError(t, err)
+ assert.EqualValues(t, 1, cnt)
+
+ cnt, err = testEngine.Insert(&CompositeKey{11, 22, ""})
+ assert.Error(t, err)
+ assert.NotEqual(t, int64(1), cnt)
+
+ var compositeKeyVal CompositeKey
+ has, err := testEngine.ID(schemas.PK{11, 22}).Get(&compositeKeyVal)
+ assert.NoError(t, err)
+ assert.True(t, has)
+
+ var compositeKeyVal2 CompositeKey
+ // test passing a PK pointer; this test seems to fail when the cache is enabled
+ has, err = testEngine.ID(&schemas.PK{11, 22}).Get(&compositeKeyVal2)
+ assert.NoError(t, err)
+ assert.True(t, has)
+ assert.EqualValues(t, compositeKeyVal, compositeKeyVal2)
+
+ var cps = make([]CompositeKey, 0)
+ err = testEngine.Find(&cps)
+ assert.NoError(t, err)
+ assert.EqualValues(t, 1, len(cps))
+ assert.EqualValues(t, cps[0], compositeKeyVal)
+
+ cnt, err = testEngine.Insert(&CompositeKey{22, 22, ""})
+ assert.NoError(t, err)
+ assert.EqualValues(t, 1, cnt)
+
+ cps = make([]CompositeKey, 0)
+ err = testEngine.Find(&cps)
+ assert.NoError(t, err)
+ assert.EqualValues(t, 2, len(cps), "should have two records")
+ assert.EqualValues(t, compositeKeyVal, cps[0], "should be equal")
+
+ compositeKeyVal = CompositeKey{UpdateStr: "test1"}
+ cnt, err = testEngine.ID(schemas.PK{11, 22}).Update(&compositeKeyVal)
+ assert.NoError(t, err)
+ assert.EqualValues(t, 1, cnt)
+
+ cnt, err = testEngine.ID(schemas.PK{11, 22}).Delete(&CompositeKey{})
+ assert.NoError(t, err)
+ assert.EqualValues(t, 1, cnt)
+}
+
+func TestCompositeKey2(t *testing.T) {
+ assert.NoError(t, PrepareEngine())
+
+ type User struct {
+ UserId string `xorm:"varchar(19) not null pk"`
+ NickName string `xorm:"varchar(19) not null"`
+ GameId uint32 `xorm:"integer pk"`
+ Score int32 `xorm:"integer"`
+ }
+
+ err := testEngine.DropTables(&User{})
+ assert.NoError(t, err)
+
+ err = testEngine.CreateTables(&User{})
+ assert.NoError(t, err)
+
+ cnt, err := testEngine.Insert(&User{"11", "nick", 22, 5})
+ assert.NoError(t, err)
+ assert.EqualValues(t, 1, cnt)
+
+ cnt, err = testEngine.Insert(&User{"11", "nick", 22, 6})
+ assert.Error(t, err)
+ assert.NotEqual(t, 1, cnt)
+
+ var user User
+ has, err := testEngine.ID(schemas.PK{"11", 22}).Get(&user)
+ assert.NoError(t, err)
+ assert.True(t, has)
+
+ // test passing a PK pointer; this test seems to fail when the cache is enabled
+ has, err = testEngine.ID(&schemas.PK{"11", 22}).Get(&user)
+ assert.NoError(t, err)
+ assert.True(t, has)
+
+ user = User{NickName:
"test1"} + cnt, err = testEngine.ID(schemas.PK{"11", 22}).Update(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.ID(schemas.PK{"11", 22}).Delete(&User{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +type MyString string +type UserPK2 struct { + UserId MyString `xorm:"varchar(19) not null pk"` + NickName string `xorm:"varchar(19) not null"` + GameId uint32 `xorm:"integer pk"` + Score int32 `xorm:"integer"` +} + +func TestCompositeKey3(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&UserPK2{}) + + assert.NoError(t, err) + + err = testEngine.CreateTables(&UserPK2{}) + assert.NoError(t, err) + + cnt, err := testEngine.Insert(&UserPK2{"11", "nick", 22, 5}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.Insert(&UserPK2{"11", "nick", 22, 6}) + assert.Error(t, err) + assert.NotEqual(t, 1, cnt) + + var user UserPK2 + has, err := testEngine.ID(schemas.PK{"11", 22}).Get(&user) + assert.NoError(t, err) + assert.True(t, has) + + // test passing PK ptr, this test seem failed withCache + has, err = testEngine.ID(&schemas.PK{"11", 22}).Get(&user) + assert.NoError(t, err) + assert.True(t, has) + + user = UserPK2{NickName: "test1"} + cnt, err = testEngine.ID(schemas.PK{"11", 22}).Update(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.ID(schemas.PK{"11", 22}).Delete(&UserPK2{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestMyIntId(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&MyIntPK{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&MyIntPK{}) + assert.NoError(t, err) + + idbean := &MyIntPK{Name: "test"} + cnt, err := testEngine.Insert(idbean) + assert.NoError(t, err) + + assert.EqualValues(t, 1, cnt) + + bean := new(MyIntPK) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, bean.ID, idbean.ID) + + var beans []MyIntPK + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + assert.EqualValues(t, *bean, beans[0]) + + beans2 := make(map[ID]MyIntPK, 0) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans2)) + assert.EqualValues(t, *bean, beans2[bean.ID]) + + cnt, err = testEngine.ID(bean.ID).Delete(&MyIntPK{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestMyStringId(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&MyStringPK{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&MyStringPK{}) + assert.NoError(t, err) + + idbean := &MyStringPK{ID: "1111", Name: "test"} + cnt, err := testEngine.Insert(idbean) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + bean := new(MyStringPK) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, bean.ID, idbean.ID) + + var beans []MyStringPK + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + assert.EqualValues(t, *bean, beans[0]) + + beans2 := make(map[StrID]MyStringPK, 0) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans2)) + assert.EqualValues(t, *bean, beans2[bean.ID]) + + cnt, err = testEngine.ID(bean.ID).Delete(&MyStringPK{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestSingleAutoIncrColumn(t *testing.T) { + type Account struct { + 
Id int64 `xorm:"pk autoincr"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(Account)) + + _, err := testEngine.Insert(&Account{}) + assert.NoError(t, err) +} + +func TestCompositePK(t *testing.T) { + type TaskSolution struct { + UID string `xorm:"notnull pk UUID 'uid'"` + TID string `xorm:"notnull pk UUID 'tid'"` + Created time.Time `xorm:"created"` + Updated time.Time `xorm:"updated"` + } + + assert.NoError(t, PrepareEngine()) + + tables1, err := testEngine.DBMetas() + assert.NoError(t, err) + + assertSync(t, new(TaskSolution)) + assert.NoError(t, testEngine.Sync2(new(TaskSolution))) + + tables2, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1+len(tables1), len(tables2)) + + var table *schemas.Table + for _, t := range tables2 { + if t.Name == testEngine.GetTableMapper().Obj2Table("TaskSolution") { + table = t + break + } + } + + assert.NotEqual(t, nil, table) + + pkCols := table.PKColumns() + assert.EqualValues(t, 2, len(pkCols)) + + names := []string{pkCols[0].Name, pkCols[1].Name} + sort.Strings(names) + assert.EqualValues(t, []string{"tid", "uid"}, names) +} + +func TestNoPKIdQueryUpdate(t *testing.T) { + type NoPKTable struct { + Username string + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(NoPKTable)) + + cnt, err := testEngine.Insert(&NoPKTable{ + Username: "test", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var res NoPKTable + has, err := testEngine.ID("test").Get(&res) + assert.Error(t, err) + assert.False(t, has) + + cnt, err = testEngine.ID("test").Update(&NoPKTable{ + Username: "test1", + }) + assert.Error(t, err) + assert.EqualValues(t, 0, cnt) + + type UnvalidPKTable struct { + ID int `xorm:"id"` + Username string + } + + assertSync(t, new(UnvalidPKTable)) + + cnt, err = testEngine.Insert(&UnvalidPKTable{ + ID: 1, + Username: "test", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var res2 UnvalidPKTable + has, err = testEngine.ID(1).Get(&res2) + assert.Error(t, err) + assert.False(t, has) + + cnt, err = testEngine.ID(1).Update(&UnvalidPKTable{ + Username: "test1", + }) + assert.Error(t, err) + assert.EqualValues(t, 0, cnt) +} diff --git a/session_query_test.go b/integrations/session_query_test.go similarity index 91% rename from session_query_test.go rename to integrations/session_query_test.go index 59d52d7..a32f6fb 100644 --- a/session_query_test.go +++ b/integrations/session_query_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-package xorm +package integrations import ( "fmt" @@ -11,13 +11,13 @@ import ( "time" "github.com/xormplus/builder" - "github.com/xormplus/core" + "github.com/xormplus/xorm/schemas" "github.com/stretchr/testify/assert" ) func TestQueryString(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type GetVar2 struct { Id int64 `xorm:"autoincr pk"` @@ -48,7 +48,7 @@ func TestQueryString(t *testing.T) { } func TestQueryString2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type GetVar3 struct { Id int64 `xorm:"autoincr pk"` @@ -108,7 +108,7 @@ func toFloat64(i interface{}) float64 { } func TestQueryInterface(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type GetVarInterface struct { Id int64 `xorm:"autoincr pk"` @@ -139,7 +139,7 @@ func TestQueryInterface(t *testing.T) { } func TestQueryNoParams(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type QueryNoParams struct { Id int64 `xorm:"autoincr pk"` @@ -188,7 +188,7 @@ func TestQueryNoParams(t *testing.T) { } func TestQueryStringNoParam(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type GetVar4 struct { Id int64 `xorm:"autoincr pk"` @@ -207,7 +207,7 @@ func TestQueryStringNoParam(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0]["id"]) - if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL { assert.EqualValues(t, "false", records[0]["msg"]) } else { assert.EqualValues(t, "0", records[0]["msg"]) @@ -217,7 +217,7 @@ func TestQueryStringNoParam(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0]["id"]) - if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL { assert.EqualValues(t, "false", records[0]["msg"]) } else { assert.EqualValues(t, "0", records[0]["msg"]) @@ -225,7 +225,7 @@ func TestQueryStringNoParam(t *testing.T) { } func TestQuerySliceStringNoParam(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type GetVar6 struct { Id int64 `xorm:"autoincr pk"` @@ -244,7 +244,7 @@ func TestQuerySliceStringNoParam(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0][0]) - if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL { assert.EqualValues(t, "false", records[0][1]) } else { assert.EqualValues(t, "0", records[0][1]) @@ -254,7 +254,7 @@ func TestQuerySliceStringNoParam(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0][0]) - if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL { assert.EqualValues(t, "false", records[0][1]) } else { assert.EqualValues(t, "0", records[0][1]) @@ -262,7 +262,7 @@ func 
TestQuerySliceStringNoParam(t *testing.T) { } func TestQueryInterfaceNoParam(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type GetVar5 struct { Id int64 `xorm:"autoincr pk"` @@ -291,7 +291,7 @@ func TestQueryInterfaceNoParam(t *testing.T) { } func TestQueryWithBuilder(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type QueryWithBuilder struct { Id int64 `xorm:"autoincr pk"` @@ -336,7 +336,7 @@ func TestQueryWithBuilder(t *testing.T) { } func TestJoinWithSubQuery(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type JoinWithSubQuery1 struct { Id int64 `xorm:"autoincr pk"` diff --git a/session_raw_test.go b/integrations/session_raw_test.go similarity index 94% rename from session_raw_test.go rename to integrations/session_raw_test.go index 766206a..8b9d676 100644 --- a/session_raw_test.go +++ b/integrations/session_raw_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "strconv" @@ -12,7 +12,7 @@ import ( ) func TestExecAndQuery(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserinfoQuery struct { Uid int diff --git a/session_schema_test.go b/integrations/session_schema_test.go similarity index 89% rename from session_schema_test.go rename to integrations/session_schema_test.go index 4505381..005b661 100644 --- a/session_schema_test.go +++ b/integrations/session_schema_test.go @@ -2,11 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "fmt" - "os" "testing" "time" @@ -14,7 +13,7 @@ import ( ) func TestStoreEngine(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.DropTables("user_store_engine")) @@ -27,7 +26,7 @@ func TestStoreEngine(t *testing.T) { } func TestCreateTable(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.DropTables("user_user")) @@ -40,7 +39,7 @@ func TestCreateTable(t *testing.T) { } func TestCreateMultiTables(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) session := testEngine.NewSession() defer session.Close() @@ -95,7 +94,7 @@ func (s *SyncTable3) TableName() string { } func TestSyncTable(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync2(new(SyncTable1))) @@ -120,7 +119,7 @@ func TestSyncTable(t *testing.T) { } func TestSyncTable2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Table("sync_tablex").Sync2(new(SyncTable1))) @@ -145,7 +144,7 @@ func TestSyncTable2(t *testing.T) { } func TestIsTableExist(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) exist, err := testEngine.IsTableExist(new(CustomTableName)) assert.NoError(t, err) @@ -159,7 +158,7 @@ func TestIsTableExist(t *testing.T) { } func TestIsTableEmpty(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type NumericEmpty struct { Numeric float64 `xorm:"numeric(26,2)"` @@ -202,7 +201,7 @@ func (c *CustomTableName) TableName() string { } func TestCustomTableName(t *testing.T) { - assert.NoError(t, prepareEngine()) + 
assert.NoError(t, PrepareEngine()) c := new(CustomTableName) assert.NoError(t, testEngine.DropTables(c)) @@ -210,14 +209,6 @@ func TestCustomTableName(t *testing.T) { assert.NoError(t, testEngine.CreateTables(c)) } -func TestDump(t *testing.T) { - assert.NoError(t, prepareEngine()) - - fp := testEngine.Dialect().URI().DbName + ".sql" - os.Remove(fp) - assert.NoError(t, testEngine.DumpAllToFile(fp)) -} - type IndexOrUnique struct { Id int64 Index int `xorm:"index"` @@ -229,7 +220,7 @@ type IndexOrUnique struct { } func TestIndexAndUnique(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.CreateTables(&IndexOrUnique{})) @@ -245,7 +236,7 @@ func TestIndexAndUnique(t *testing.T) { } func TestMetaInfo(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync2(new(CustomTableName), new(IndexOrUnique))) tables, err := testEngine.DBMetas() @@ -257,29 +248,23 @@ func TestMetaInfo(t *testing.T) { } func TestCharst(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) err := testEngine.DropTables("user_charset") - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) err = testEngine.Charset("utf8").Table("user_charset").CreateTable(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } func TestSync2_1(t *testing.T) { type WxTest struct { - Id int `xorm:"not null pk autoincr INT(64)` + Id int `xorm:"not null pk autoincr INT(64)"` Passport_user_type int16 `xorm:"null int"` Id_delete int8 `xorm:"null int default 1"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.DropTables("wx_test")) assert.NoError(t, testEngine.Sync2(new(WxTest))) @@ -296,7 +281,7 @@ func TestUnique_1(t *testing.T) { UpdatedAt time.Time `xorm:"updated"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.DropTables("user_unique")) assert.NoError(t, testEngine.Sync2(new(UserUnique))) @@ -312,7 +297,7 @@ func TestSync2_2(t *testing.T) { UserId int64 `xorm:"index"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) var tableNames = make(map[string]bool) for i := 0; i < 10; i++ { @@ -341,7 +326,7 @@ func TestSync2_Default(t *testing.T) { Name string `xorm:"default('my_name')"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(TestSync2Default)) assert.NoError(t, testEngine.Sync2(new(TestSync2Default))) } diff --git a/session_stats_test.go b/integrations/session_stats_test.go similarity index 87% rename from session_stats_test.go rename to integrations/session_stats_test.go index 4f06cd0..2ad8207 100644 --- a/session_stats_test.go +++ b/integrations/session_stats_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-package xorm +package integrations import ( "fmt" @@ -23,7 +23,7 @@ func TestSum(t *testing.T) { Float float32 } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync2(new(SumStruct))) var ( @@ -82,7 +82,7 @@ func (s SumStructWithTableName) TableName() string { } func TestSumWithTableName(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync2(new(SumStructWithTableName))) var ( @@ -132,7 +132,7 @@ func TestSumWithTableName(t *testing.T) { } func TestSumCustomColumn(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type SumStruct2 struct { Int int @@ -160,7 +160,7 @@ func TestSumCustomColumn(t *testing.T) { } func TestCount(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserinfoCount struct { Departname string @@ -196,7 +196,7 @@ func TestCount(t *testing.T) { } func TestSQLCount(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserinfoCount2 struct { Id int64 @@ -218,7 +218,7 @@ func TestSQLCount(t *testing.T) { } func TestCountWithOthers(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type CountWithOthers struct { Id int64 @@ -252,7 +252,7 @@ func (CountWithTableName) TableName() string { } func TestWithTableName(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(CountWithTableName)) @@ -274,3 +274,27 @@ func TestWithTableName(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, total) } + +func TestCountWithSelectCols(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + assertSync(t, new(CountWithTableName)) + + _, err := testEngine.Insert(&CountWithTableName{ + Name: "orderby", + }) + assert.NoError(t, err) + + _, err = testEngine.Insert(CountWithTableName{ + Name: "limit", + }) + assert.NoError(t, err) + + total, err := testEngine.Cols("id").Count(new(CountWithTableName)) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) + + total, err = testEngine.Select("count(id)").Count(CountWithTableName{}) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) +} diff --git a/session_test.go b/integrations/session_test.go similarity index 70% rename from session_test.go rename to integrations/session_test.go index 343f9ba..bdf3278 100644 --- a/session_test.go +++ b/integrations/session_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-package xorm +package integrations import ( "database/sql" @@ -12,7 +12,7 @@ import ( ) func TestClose(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) sess1 := testEngine.NewSession() sess1.Close() @@ -31,7 +31,7 @@ func TestNullFloatStruct(t *testing.T) { Amount MyNullFloat64 } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync2(new(MyNullFloatStruct))) _, err := testEngine.Insert(&MyNullFloatStruct{ @@ -43,3 +43,14 @@ func TestNullFloatStruct(t *testing.T) { }) assert.NoError(t, err) } + +func TestMustLogSQL(t *testing.T) { + assert.NoError(t, PrepareEngine()) + testEngine.ShowSQL(false) + defer testEngine.ShowSQL(true) + + assertSync(t, new(Userinfo)) + + _, err := testEngine.Table("userinfo").MustLogSQL(true).Get(new(Userinfo)) + assert.NoError(t, err) +} diff --git a/session_tx_test.go b/integrations/session_tx_test.go similarity index 84% rename from session_tx_test.go rename to integrations/session_tx_test.go index c132950..1a81639 100644 --- a/session_tx_test.go +++ b/integrations/session_tx_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "fmt" @@ -10,22 +10,20 @@ import ( "time" "github.com/stretchr/testify/assert" - "github.com/xormplus/core" + "github.com/xormplus/xorm/internal/utils" + "github.com/xormplus/xorm/names" ) func TestTransaction(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) - counter := func() { - total, err := testEngine.Count(&Userinfo{}) - if err != nil { - t.Error(err) - } - fmt.Printf("----now total %v records\n", total) + counter := func(t *testing.T) { + _, err := testEngine.Count(&Userinfo{}) + assert.NoError(t, err) } - counter() + counter(t) //defer counter() session := testEngine.NewSession() @@ -39,7 +37,7 @@ func TestTransaction(t *testing.T) { assert.NoError(t, err) user2 := Userinfo{Username: "yyy"} - _, err = session.Where("(id) = ?", 0).Update(&user2) + _, err = session.Where("id = ?", 0).Update(&user2) assert.NoError(t, err) _, err = session.Delete(&user2) @@ -50,14 +48,12 @@ func TestTransaction(t *testing.T) { } func TestCombineTransaction(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) counter := func() { total, err := testEngine.Count(&Userinfo{}) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) fmt.Printf("----now total %v records\n", total) } @@ -85,13 +81,13 @@ func TestCombineTransaction(t *testing.T) { } func TestCombineTransactionSameMapper(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) oldMapper := testEngine.GetColumnMapper() - testEngine.UnMapType(rValue(new(Userinfo)).Type()) - testEngine.SetMapper(core.SameMapper{}) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) + testEngine.SetMapper(names.SameMapper{}) defer func() { - testEngine.UnMapType(rValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) testEngine.SetMapper(oldMapper) }() @@ -99,9 +95,7 @@ func TestCombineTransactionSameMapper(t *testing.T) { counter := func() { total, err := testEngine.Count(&Userinfo{}) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) fmt.Printf("----now total %v records\n", total) } @@ -119,7 +113,7 @@ func TestCombineTransactionSameMapper(t *testing.T) { 
assert.NoError(t, err) user2 := Userinfo{Username: "zzz"} - _, err = session.Where("(id) = ?", 0).Update(&user2) + _, err = session.Where("id = ?", 0).Update(&user2) assert.NoError(t, err) _, err = session.Exec("delete from "+testEngine.TableName("`Userinfo`", true)+" where `Username` = ?", user2.Username) @@ -130,7 +124,7 @@ func TestCombineTransactionSameMapper(t *testing.T) { } func TestMultipleTransaction(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type MultipleTransaction struct { Id int64 diff --git a/session_update_test.go b/integrations/session_update_test.go similarity index 76% rename from session_update_test.go rename to integrations/session_update_test.go index 8a13ab3..8f43bd0 100644 --- a/session_update_test.go +++ b/integrations/session_update_test.go @@ -2,21 +2,22 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( - "errors" "fmt" "sync" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/xormplus/core" + "github.com/xormplus/xorm" + "github.com/xormplus/xorm/internal/utils" + "github.com/xormplus/xorm/names" ) func TestUpdateMap(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UpdateTable struct { Id int64 @@ -41,7 +42,12 @@ func TestUpdateMap(t *testing.T) { } func TestUpdateLimit(t *testing.T) { - assert.NoError(t, prepareEngine()) + if *ingoreUpdateLimit { + t.Skip() + return + } + + assert.NoError(t, PrepareEngine()) type UpdateTable2 struct { Id int64 @@ -83,7 +89,7 @@ type ForUpdate struct { Name string } -func setupForUpdate(engine EngineInterface) error { +func setupForUpdate(engine xorm.EngineInterface) error { v := new(ForUpdate) err := testEngine.DropTables(v) if err != nil { @@ -137,7 +143,7 @@ func TestForUpdate(t *testing.T) { // use lock fList := make([]ForUpdate, 0) session1.ForUpdate() - session1.Where("(id) = ?", 1) + session1.Where("id = ?", 1) err = session1.Find(&fList) switch { case err != nil: @@ -158,7 +164,7 @@ func TestForUpdate(t *testing.T) { wg.Add(1) go func() { f2 := new(ForUpdate) - session2.Where("(id) = ?", 1).ForUpdate() + session2.Where("id = ?", 1).ForUpdate() has, err := session2.Get(f2) // wait release lock switch { case err != nil: @@ -175,7 +181,7 @@ func TestForUpdate(t *testing.T) { wg.Add(1) go func() { f3 := new(ForUpdate) - session3.Where("(id) = ?", 1) + session3.Where("id = ?", 1) has, err := session3.Get(f3) // wait release lock switch { case err != nil: @@ -193,7 +199,7 @@ func TestForUpdate(t *testing.T) { f := new(ForUpdate) f.Name = "updated by session1" - session1.Where("(id) = ?", 1) + session1.Where("id = ?", 1) session1.Update(f) // release lock @@ -213,7 +219,7 @@ func TestWithIn(t *testing.T) { Test bool `xorm:"Test"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync(new(temp3))) testEngine.Insert(&[]temp3{ @@ -265,20 +271,17 @@ type Article struct { } func TestUpdateMap2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(UpdateMustCols)) _, err := testEngine.Table("update_must_cols").Where("id =?", 1).Update(map[string]interface{}{ "bool": true, }) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } func TestUpdate1(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) _, err := 
testEngine.Insert(&Userinfo{ @@ -287,68 +290,32 @@ func TestUpdate1(t *testing.T) { var ori Userinfo has, err := testEngine.Get(&ori) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - t.Error(errors.New("not exist")) - panic(errors.New("not exist")) - } + assert.NoError(t, err) + assert.True(t, has) // update by id user := Userinfo{Username: "xxx", Height: 1.2} cnt, err := testEngine.ID(ori.Uid).Update(&user) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update not returned 1") - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) condi := Condi{"username": "zzz", "departname": ""} cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update not returned 1") - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) cnt, err = testEngine.Update(&Userinfo{Username: "yyy"}, &user) - if err != nil { - t.Error(err) - panic(err) - } - total, err := testEngine.Count(&user) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) - if cnt != total { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } + total, err := testEngine.Count(&user) + assert.NoError(t, err) + assert.EqualValues(t, cnt, total) // nullable update { user := &Userinfo{Username: "not null data", Height: 180.5} _, err := testEngine.Insert(user) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) userID := user.Uid has, err := testEngine.ID(userID). @@ -358,29 +325,15 @@ func TestUpdate1(t *testing.T) { And("detail_id = ?", 0). And("is_man = ?", 0). Get(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("cannot insert properly") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.True(t, has, "cannot insert properly") updatedUser := &Userinfo{Username: "null data"} cnt, err = testEngine.ID(userID). Nullable("height", "departname", "is_man", "created"). Update(updatedUser) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update not returned 1") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt, "update not returned 1") has, err = testEngine.ID(userID). And("username = ?", updatedUser.Username). @@ -390,73 +343,31 @@ func TestUpdate1(t *testing.T) { And("created IS NULL"). And("detail_id = ?", 0). 
Get(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("cannot update with null properly") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.True(t, has, "cannot update with null properly") cnt, err = testEngine.ID(userID).Delete(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("delete not returned 1") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt, "delete not returned 1") } err = testEngine.StoreEngine("Innodb").Sync2(&Article{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) defer func() { err = testEngine.DropTables(&Article{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) }() a := &Article{0, "1", "2", "3", "4", "5", 2} cnt, err = testEngine.Insert(a) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt)) - t.Error(err) - panic(err) - } - - if a.Id == 0 { - err = errors.New("insert returned id is 0") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt, fmt.Sprintf("insert not returned 1 but %d", cnt)) + assert.Greater(t, a.Id, int32(0), "insert returned id is 0") cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"}) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt)) - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) var s = "test" @@ -474,70 +385,33 @@ func TestUpdate1(t *testing.T) { col3 := &UpdateAllCols{} has, err = testEngine.ID(col2.Id).Get(col3) assert.NoError(t, err) - - if !has { - err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) - t.Error(err) - panic(err) - return - } - - if *col2 != *col3 { - err = errors.New(fmt.Sprintf("col2 should eq col3")) - t.Error(err) - panic(err) - return - } + assert.True(t, has) + assert.EqualValues(t, *col2, *col3) { - col1 := &UpdateMustCols{} err = testEngine.Sync(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) col2 := &UpdateMustCols{col1.Id, true, ""} boolStr := testEngine.GetColumnMapper().Obj2Table("Bool") stringStr := testEngine.GetColumnMapper().Obj2Table("String") _, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) col3 := &UpdateMustCols{} has, err := testEngine.ID(col2.Id).Get(col3) - if err != nil { - t.Error(err) - panic(err) - } - - if !has { - err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) - t.Error(err) - panic(err) - return - } - - if *col2 != *col3 { - err = errors.New(fmt.Sprintf("col2 should eq col3")) - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, *col2, *col3) } } func TestUpdateIncrDecr(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) col1 := &UpdateIncr{ Name: "test", @@ -600,36 +474,23 @@ type UpdatedUpdate5 struct { } func TestUpdateUpdated(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) di := new(UpdatedUpdate) err := testEngine.Sync2(di) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(&UpdatedUpdate{}) - if err != nil { - 
t.Fatal(err) - } + assert.NoError(t, err) ci := &UpdatedUpdate{} _, err = testEngine.ID(1).Update(ci) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) has, err := testEngine.ID(1).Get(di) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci.Updated.Unix() != di.Updated.Unix() { - t.Fatal("should equal:", ci, di) - } - fmt.Println("ci:", ci, "di:", di) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci.Updated.Unix(), di.Updated.Unix()) di2 := new(UpdatedUpdate2) err = testEngine.Sync2(di2) @@ -660,108 +521,71 @@ func TestUpdateUpdated(t *testing.T) { di3 := new(UpdatedUpdate3) err = testEngine.Sync2(di3) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(&UpdatedUpdate3{}) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + ci3 := &UpdatedUpdate3{} _, err = testEngine.ID(1).Update(ci3) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) has, err = testEngine.ID(1).Get(di3) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci3.Updated != di3.Updated { - t.Fatal("should equal:", ci3, di3) - } - fmt.Println("ci3:", ci3, "di3:", di3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci3.Updated, di3.Updated) di4 := new(UpdatedUpdate4) err = testEngine.Sync2(di4) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(&UpdatedUpdate4{}) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) ci4 := &UpdatedUpdate4{} _, err = testEngine.ID(1).Update(ci4) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) has, err = testEngine.ID(1).Get(di4) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci4.Updated != di4.Updated { - t.Fatal("should equal:", ci4, di4) - } - fmt.Println("ci4:", ci4, "di4:", di4) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci4.Updated, di4.Updated) di5 := new(UpdatedUpdate5) err = testEngine.Sync2(di5) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(&UpdatedUpdate5{}) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + ci5 := &UpdatedUpdate5{} _, err = testEngine.ID(1).Update(ci5) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) has, err = testEngine.ID(1).Get(di5) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci5.Updated.Unix() != di5.Updated.Unix() { - t.Fatal("should equal:", ci5, di5) - } - fmt.Println("ci5:", ci5, "di5:", di5) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci5.Updated.Unix(), di5.Updated.Unix()) } func TestUpdateSameMapper(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) oldMapper := testEngine.GetTableMapper() - testEngine.UnMapType(rValue(new(Userinfo)).Type()) - testEngine.UnMapType(rValue(new(Condi)).Type()) - testEngine.UnMapType(rValue(new(Article)).Type()) - testEngine.UnMapType(rValue(new(UpdateAllCols)).Type()) - testEngine.UnMapType(rValue(new(UpdateMustCols)).Type()) - testEngine.UnMapType(rValue(new(UpdateIncr)).Type()) - testEngine.SetMapper(core.SameMapper{}) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Condi)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Article)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(UpdateAllCols)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(UpdateMustCols)).Type()) + 
testEngine.UnMapType(utils.ReflectValue(new(UpdateIncr)).Type()) + testEngine.SetMapper(names.SameMapper{}) defer func() { - testEngine.UnMapType(rValue(new(Userinfo)).Type()) - testEngine.UnMapType(rValue(new(Condi)).Type()) - testEngine.UnMapType(rValue(new(Article)).Type()) - testEngine.UnMapType(rValue(new(UpdateAllCols)).Type()) - testEngine.UnMapType(rValue(new(UpdateMustCols)).Type()) - testEngine.UnMapType(rValue(new(UpdateIncr)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Condi)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Article)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(UpdateAllCols)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(UpdateMustCols)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(UpdateIncr)).Type()) testEngine.SetMapper(oldMapper) }() @@ -806,28 +630,12 @@ func TestUpdateSameMapper(t *testing.T) { a := &Article{0, "1", "2", "3", "4", "5", 2} cnt, err = testEngine.Insert(a) assert.NoError(t, err) - - if cnt != 1 { - err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt)) - t.Error(err) - panic(err) - } - - if a.Id == 0 { - err = errors.New("insert returned id is 0") - t.Error(err) - panic(err) - } + assert.EqualValues(t, 1, cnt) + assert.Greater(t, a.Id, int32(0)) cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"}) assert.NoError(t, err) - - if cnt != 1 { - err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt)) - t.Error(err) - panic(err) - return - } + assert.EqualValues(t, 1, cnt) col1 := &UpdateAllCols{} err = testEngine.Sync(col1) @@ -843,20 +651,8 @@ func TestUpdateSameMapper(t *testing.T) { col3 := &UpdateAllCols{} has, err = testEngine.ID(col2.Id).Get(col3) assert.NoError(t, err) - - if !has { - err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) - t.Error(err) - panic(err) - return - } - - if *col2 != *col3 { - err = errors.New(fmt.Sprintf("col2 should eq col3")) - t.Error(err) - panic(err) - return - } + assert.True(t, has) + assert.EqualValues(t, *col2, *col3) { col1 := &UpdateMustCols{} @@ -875,94 +671,49 @@ func TestUpdateSameMapper(t *testing.T) { col3 := &UpdateMustCols{} has, err := testEngine.ID(col2.Id).Get(col3) assert.NoError(t, err) - - if !has { - err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) - t.Error(err) - panic(err) - return - } - - if *col2 != *col3 { - err = errors.New(fmt.Sprintf("col2 should eq col3")) - t.Error(err) - panic(err) - return - } + assert.True(t, has) + assert.EqualValues(t, *col2, *col3) } { col1 := &UpdateIncr{} err = testEngine.Sync(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) cnt, err := testEngine.ID(col1.Id).Incr("`Cnt`").Update(col1) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update incr failed") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) newCol := new(UpdateIncr) has, err := testEngine.ID(col1.Id).Get(newCol) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("has incr failed") - t.Error(err) - panic(err) - } - if 1 != newCol.Cnt { - err = errors.New("incr failed") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 1, newCol.Cnt) } } func TestUseBool(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) 
assertSync(t, new(Userinfo))

 cnt1, err := testEngine.Count(&Userinfo{})
- if err != nil {
- t.Error(err)
- panic(err)
- }
+ assert.NoError(t, err)

 users := make([]Userinfo, 0)
 err = testEngine.Find(&users)
- if err != nil {
- t.Error(err)
- panic(err)
- }
+ assert.NoError(t, err)

 var fNumber int64
 for _, u := range users {
 if u.IsMan == false {
- fNumber += 1
+ fNumber++
 }
 }

 cnt2, err := testEngine.UseBool().Update(&Userinfo{IsMan: true})
- if err != nil {
- t.Error(err)
- panic(err)
- }
+ assert.NoError(t, err)
 if fNumber != cnt2 {
 fmt.Println("cnt1", cnt1, "fNumber", fNumber, "cnt2", cnt2)
 /*err = errors.New("Updated number is not corrected.")
@@ -971,58 +722,34 @@ func TestUseBool(t *testing.T) {
 }

 _, err = testEngine.Update(&Userinfo{IsMan: true})
- if err == nil {
- err = errors.New("error condition")
- t.Error(err)
- panic(err)
- }
+ assert.Error(t, err)
 }

 func TestBool(t *testing.T) {
- assert.NoError(t, prepareEngine())
+ assert.NoError(t, PrepareEngine())

 assertSync(t, new(Userinfo))

 _, err := testEngine.UseBool().Update(&Userinfo{IsMan: true})
- if err != nil {
- t.Error(err)
- panic(err)
- }
+ assert.NoError(t, err)

 users := make([]Userinfo, 0)
 err = testEngine.Find(&users)
- if err != nil {
- t.Error(err)
- panic(err)
- }
+ assert.NoError(t, err)
 for _, user := range users {
- if !user.IsMan {
- err = errors.New("update bool or find bool error")
- t.Error(err)
- panic(err)
- }
+ assert.True(t, user.IsMan)
 }

 _, err = testEngine.UseBool().Update(&Userinfo{IsMan: false})
- if err != nil {
- t.Error(err)
- panic(err)
- }
+ assert.NoError(t, err)

 users = make([]Userinfo, 0)
 err = testEngine.Find(&users)
- if err != nil {
- t.Error(err)
- panic(err)
- }
+ assert.NoError(t, err)
 for _, user := range users {
- if user.IsMan {
- err = errors.New("update bool or find bool error")
- t.Error(err)
- panic(err)
- }
+ assert.False(t, user.IsMan)
 }
 }

 func TestNoUpdate(t *testing.T) {
- assert.NoError(t, prepareEngine())
+ assert.NoError(t, PrepareEngine())

 type NoUpdate struct {
 Id int64
@@ -1043,7 +770,7 @@ func TestNoUpdate(t *testing.T) {
 }

 func TestNewUpdate(t *testing.T) {
- assert.NoError(t, prepareEngine())
+ assert.NoError(t, PrepareEngine())

 type TbUserInfo struct {
 Id int64 `xorm:"pk autoincr unique BIGINT" json:"id"`
@@ -1073,7 +800,7 @@ func TestNewUpdate(t *testing.T) {
 }

 func TestUpdateUpdate(t *testing.T) {
- assert.NoError(t, prepareEngine())
+ assert.NoError(t, PrepareEngine())

 type PublicKeyUpdate struct {
 Id int64
@@ -1090,7 +817,7 @@ func TestUpdateUpdate(t *testing.T) {
 }

 func TestCreatedUpdated2(t *testing.T) {
- assert.NoError(t, prepareEngine())
+ assert.NoError(t, PrepareEngine())

 type CreatedUpdatedStruct struct {
 Id int64
@@ -1134,7 +861,7 @@ func TestCreatedUpdated2(t *testing.T) {
 }

 func TestDeletedUpdate(t *testing.T) {
- assert.NoError(t, prepareEngine())
+ assert.NoError(t, PrepareEngine())

 type DeletedUpdatedStruct struct {
 Id int64
@@ -1182,7 +909,7 @@ func TestDeletedUpdate(t *testing.T) {
 }

 func TestUpdateMapCondition(t *testing.T) {
- assert.NoError(t, prepareEngine())
+ assert.NoError(t, PrepareEngine())

 type UpdateMapCondition struct {
 Id int64
@@ -1213,7 +940,7 @@ func TestUpdateMapCondition(t *testing.T) {
 }

 func TestUpdateMapContent(t *testing.T) {
- assert.NoError(t, prepareEngine())
+ assert.NoError(t, PrepareEngine())

 type UpdateMapContent struct {
 Id int64
@@ -1288,7 +1015,7 @@ func TestUpdateCondiBean(t *testing.T) {
 Name string
 }

- assert.NoError(t, prepareEngine())
+ assert.NoError(t, PrepareEngine())
 assertSync(t, new(NeedUpdateBean))

 cnt, err
:= testEngine.Insert(&NeedUpdateBean{ @@ -1338,7 +1065,7 @@ func TestWhereCondErrorWhenUpdate(t *testing.T) { RequestToken string } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(AuthRequestError)) _, err := testEngine.Cols("challenge_token", "request_token", "challenge_agent", "status"). @@ -1347,11 +1074,11 @@ func TestWhereCondErrorWhenUpdate(t *testing.T) { ChallengeToken: "2", }) assert.Error(t, err) - assert.EqualValues(t, ErrConditionType, err) + assert.EqualValues(t, xorm.ErrConditionType, err) } func TestUpdateDeleted(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UpdateDeletedStruct struct { Id int64 @@ -1392,7 +1119,7 @@ func TestUpdateDeleted(t *testing.T) { } func TestUpdateExprs(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UpdateExprs struct { Id int64 @@ -1423,7 +1150,7 @@ func TestUpdateExprs(t *testing.T) { } func TestUpdateAlias(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UpdateAlias struct { Id int64 @@ -1453,8 +1180,53 @@ func TestUpdateAlias(t *testing.T) { assert.EqualValues(t, "lunny xiao", ue.Name) } +func TestUpdateExprs2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type UpdateExprsRelease struct { + Id int64 + RepoId int + IsTag bool + IsDraft bool + NumCommits int + Sha1 string + } + + assertSync(t, new(UpdateExprsRelease)) + + var uer = UpdateExprsRelease{ + RepoId: 1, + IsTag: false, + IsDraft: false, + NumCommits: 1, + Sha1: "sha1", + } + inserted, err := testEngine.Insert(&uer) + assert.NoError(t, err) + assert.EqualValues(t, 1, inserted) + + updated, err := testEngine. + Where("repo_id = ? AND is_tag = ?", 1, false). + SetExpr("is_draft", true). + SetExpr("num_commits", 0). + SetExpr("sha1", ""). 
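+ // SetExpr assigns the columns directly in the generated SET clause rather than reading values from the bean passed to Update below.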
+ Update(new(UpdateExprsRelease)) + assert.NoError(t, err) + assert.EqualValues(t, 1, updated) + + var uer2 UpdateExprsRelease + has, err := testEngine.ID(uer.Id).Get(&uer2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 1, uer2.RepoId) + assert.EqualValues(t, false, uer2.IsTag) + assert.EqualValues(t, true, uer2.IsDraft) + assert.EqualValues(t, 0, uer2.NumCommits) + assert.EqualValues(t, "", uer2.Sha1) +} + func TestUpdateMap3(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UpdateMapUser struct { Id uint64 `xorm:"PK autoincr"` @@ -1467,7 +1239,7 @@ func TestUpdateMap3(t *testing.T) { testEngine.SetColumnMapper(oldMapper) }() - mapper := core.NewPrefixMapper(core.SnakeMapper{}, "F") + mapper := names.NewPrefixMapper(names.SnakeMapper{}, "F") testEngine.SetColumnMapper(mapper) assertSync(t, new(UpdateMapUser)) @@ -1494,3 +1266,78 @@ func TestUpdateMap3(t *testing.T) { assert.Error(t, err) assert.EqualValues(t, 0, rows) } + +func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) { + type TestOnlyFromDBField struct { + Id int64 `xorm:"PK"` + OnlyFromDBField string `xorm:"<-"` + OnlyToDBField string `xorm:"->"` + IgnoreField string `xorm:"-"` + } + + assertGetRecord := func() *TestOnlyFromDBField { + var record TestOnlyFromDBField + has, err := testEngine.Where("id = ?", 1).Get(&record) + assert.NoError(t, err) + assert.EqualValues(t, true, has) + assert.EqualValues(t, "", record.OnlyFromDBField) + return &record + } + assert.NoError(t, PrepareEngine()) + assertSync(t, new(TestOnlyFromDBField)) + + _, err := testEngine.Insert(&TestOnlyFromDBField{ + Id: 1, + OnlyFromDBField: "a", + OnlyToDBField: "b", + IgnoreField: "c", + }) + assert.NoError(t, err) + + assertGetRecord() + + _, err = testEngine.ID(1).Update(&TestOnlyFromDBField{ + OnlyToDBField: "b", + OnlyFromDBField: "test", + }) + assert.NoError(t, err) + assertGetRecord() +} + +func TestUpdateMultiplePK(t *testing.T) { + type TestUpdateMultiplePKStruct struct { + Id string `xorm:"notnull pk" description:"unique ID"` + Name string `xorm:"notnull pk" description:"name"` + Value string `xorm:"notnull varchar(4000)" description:"value"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(TestUpdateMultiplePKStruct)) + + test := &TestUpdateMultiplePKStruct{ + Id: "ID1", + Name: "Name1", + Value: "1", + } + _, err := testEngine.Insert(test) + assert.NoError(t, err) + + test.Value = "2" + _, err = testEngine.Where("`id` = ? And `name` = ?", test.Id, test.Name).Cols("Value").Update(test) + assert.NoError(t, err) + + test.Value = "3" + num, err := testEngine.Where("`id` = ? And `name` = ?", test.Id, test.Name).Update(test) + assert.NoError(t, err) + assert.EqualValues(t, 1, num) + + test.Value = "4" + _, err = testEngine.ID([]interface{}{test.Id, test.Name}).Update(test) + assert.NoError(t, err) + + type MySlice []interface{} + test.Value = "5" + _, err = testEngine.ID(&MySlice{test.Id, test.Name}).Update(test) + assert.NoError(t, err) +} diff --git a/integrations/tags_test.go b/integrations/tags_test.go new file mode 100644 index 0000000..7013864 --- /dev/null +++ b/integrations/tags_test.go @@ -0,0 +1,1329 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file.
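+ // tags_test.go exercises xorm's struct-tag handling: extends, cache/nocache, created/updated/deleted/version, default values, column direction (<-/->) and indexes.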
+ +package integrations + +import ( + "fmt" + "sort" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/xormplus/xorm/internal/utils" + "github.com/xormplus/xorm/names" + "github.com/xormplus/xorm/schemas" +) + +type tempUser struct { + Id int64 + Username string +} + +type tempUser2 struct { + TempUser tempUser `xorm:"extends"` + Departname string +} + +type tempUser3 struct { + Temp *tempUser `xorm:"extends"` + Departname string +} + +type tempUser4 struct { + TempUser2 tempUser2 `xorm:"extends"` +} + +type Userinfo struct { + Uid int64 `xorm:"id pk not null autoincr"` + Username string `xorm:"unique"` + Departname string + Alias string `xorm:"-"` + Created time.Time + Detail Userdetail `xorm:"detail_id int(11)"` + Height float64 + Avatar []byte + IsMan bool +} + +type Userdetail struct { + Id int64 + Intro string `xorm:"text"` + Profile string `xorm:"varchar(2000)"` +} + +type UserAndDetail struct { + Userinfo `xorm:"extends"` + Userdetail `xorm:"extends"` +} + +func TestExtends(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&tempUser2{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&tempUser2{}) + assert.NoError(t, err) + + tu := &tempUser2{tempUser{0, "extends"}, "dev depart"} + _, err = testEngine.Insert(tu) + assert.NoError(t, err) + + tu2 := &tempUser2{} + _, err = testEngine.Get(tu2) + assert.NoError(t, err) + + tu3 := &tempUser2{tempUser{0, "extends update"}, ""} + _, err = testEngine.ID(tu2.TempUser.Id).Update(tu3) + assert.NoError(t, err) + + err = testEngine.DropTables(&tempUser4{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&tempUser4{}) + assert.NoError(t, err) + + tu8 := &tempUser4{tempUser2{tempUser{0, "extends"}, "dev depart"}} + _, err = testEngine.Insert(tu8) + assert.NoError(t, err) + + tu9 := &tempUser4{} + _, err = testEngine.Get(tu9) + assert.NoError(t, err) + assert.EqualValues(t, tu8.TempUser2.TempUser.Username, tu9.TempUser2.TempUser.Username) + assert.EqualValues(t, tu8.TempUser2.Departname, tu9.TempUser2.Departname) + + tu10 := &tempUser4{tempUser2{tempUser{0, "extends update"}, ""}} + _, err = testEngine.ID(tu9.TempUser2.TempUser.Id).Update(tu10) + assert.NoError(t, err) + + err = testEngine.DropTables(&tempUser3{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&tempUser3{}) + assert.NoError(t, err) + + tu4 := &tempUser3{&tempUser{0, "extends"}, "dev depart"} + _, err = testEngine.Insert(tu4) + assert.NoError(t, err) + + tu5 := &tempUser3{} + _, err = testEngine.Get(tu5) + assert.NoError(t, err) + + assert.NotNil(t, tu5.Temp) + assert.EqualValues(t, 1, tu5.Temp.Id) + assert.EqualValues(t, "extends", tu5.Temp.Username) + assert.EqualValues(t, "dev depart", tu5.Departname) + + tu6 := &tempUser3{&tempUser{0, "extends update"}, ""} + _, err = testEngine.ID(tu5.Temp.Id).Update(tu6) + assert.NoError(t, err) + + users := make([]tempUser3, 0) + err = testEngine.Find(&users) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(users), "error get data not 1") + + assertSync(t, new(Userinfo), new(Userdetail)) + + detail := Userdetail{ + Intro: "I'm in China", + } + _, err = testEngine.Insert(&detail) + assert.NoError(t, err) + + _, err = testEngine.Insert(&Userinfo{ + Username: "lunny", + Detail: detail, + }) + assert.NoError(t, err) + + var info UserAndDetail + qt := testEngine.Quote + ui := testEngine.TableName(new(Userinfo), true) + ud := testEngine.TableName(&detail, true) + uiid := testEngine.GetColumnMapper().Obj2Table("Id") + udid := 
"detail_id" + sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s", + qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid)) + b, err := testEngine.SQL(sql).NoCascade().Get(&info) + assert.NoError(t, err) + assert.True(t, b, "should has lest one record") + assert.True(t, info.Userinfo.Uid > 0, "all of the id should has value") + assert.True(t, info.Userdetail.Id > 0, "all of the id should has value") + + var info2 UserAndDetail + b, err = testEngine.Table(&Userinfo{}). + Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). + NoCascade().Get(&info2) + assert.NoError(t, err) + assert.True(t, b) + assert.True(t, info2.Userinfo.Uid > 0, "all of the id should has value") + assert.True(t, info2.Userdetail.Id > 0, "all of the id should has value") + + var infos2 = make([]UserAndDetail, 0) + err = testEngine.Table(&Userinfo{}). + Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). + NoCascade(). + Find(&infos2) + assert.NoError(t, err) +} + +type MessageBase struct { + Id int64 `xorm:"int(11) pk autoincr"` + TypeId int64 `xorm:"int(11) notnull"` +} + +type Message struct { + MessageBase `xorm:"extends"` + Title string `xorm:"varchar(100) notnull"` + Content string `xorm:"text notnull"` + Uid int64 `xorm:"int(11) notnull"` + ToUid int64 `xorm:"int(11) notnull"` + CreateTime time.Time `xorm:"datetime notnull created"` +} + +type MessageUser struct { + Id int64 + Name string +} + +type MessageType struct { + Id int64 + Name string +} + +type MessageExtend3 struct { + Message `xorm:"extends"` + Sender MessageUser `xorm:"extends"` + Receiver MessageUser `xorm:"extends"` + Type MessageType `xorm:"extends"` +} + +type MessageExtend4 struct { + Message `xorm:"extends"` + MessageUser `xorm:"extends"` + MessageType `xorm:"extends"` +} + +func TestExtends2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) + assert.NoError(t, err) + + var sender = MessageUser{Name: "sender"} + var receiver = MessageUser{Name: "receiver"} + var msgtype = MessageType{Name: "type"} + _, err = testEngine.Insert(&sender, &receiver, &msgtype) + assert.NoError(t, err) + + msg := Message{ + MessageBase: MessageBase{ + Id: msgtype.Id, + }, + Title: "test", + Content: "test", + Uid: sender.Id, + ToUid: receiver.Id, + } + + session := testEngine.NewSession() + defer session.Close() + + // MSSQL deny insert identity column excep declare as below + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + err = session.Begin() + assert.NoError(t, err) + _, err = session.Exec("SET IDENTITY_INSERT message ON") + assert.NoError(t, err) + } + cnt, err := session.Insert(&msg) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + err = session.Commit() + assert.NoError(t, err) + } + + var mapper = testEngine.GetTableMapper().Obj2Table + var quote = testEngine.Quote + userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) + typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) + msgTableName := quote(testEngine.TableName(mapper("Message"), true)) + + list := make([]Message, 0) + err = session.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). 
+ Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`"). + Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). + Find(&list) + assert.NoError(t, err) + + assert.EqualValues(t, 1, len(list), fmt.Sprintln("should have 1 message, got", len(list))) + assert.EqualValues(t, msg.Id, list[0].Id, fmt.Sprintln("should message equal", list[0], msg)) +} + +func TestExtends3(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) + assert.NoError(t, err) + + var sender = MessageUser{Name: "sender"} + var receiver = MessageUser{Name: "receiver"} + var msgtype = MessageType{Name: "type"} + _, err = testEngine.Insert(&sender, &receiver, &msgtype) + assert.NoError(t, err) + + msg := Message{ + MessageBase: MessageBase{ + Id: msgtype.Id, + }, + Title: "test", + Content: "test", + Uid: sender.Id, + ToUid: receiver.Id, + } + + session := testEngine.NewSession() + defer session.Close() + + // MSSQL deny insert identity column excep declare as below + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + err = session.Begin() + assert.NoError(t, err) + _, err = session.Exec("SET IDENTITY_INSERT message ON") + assert.NoError(t, err) + } + _, err = session.Insert(&msg) + assert.NoError(t, err) + + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + err = session.Commit() + assert.NoError(t, err) + } + + var mapper = testEngine.GetTableMapper().Obj2Table + var quote = testEngine.Quote + userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) + typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) + msgTableName := quote(testEngine.TableName(mapper("Message"), true)) + + list := make([]MessageExtend3, 0) + err = session.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). + Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`"). + Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). 
+ Find(&list) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(list)) + assert.EqualValues(t, list[0].Message.Id, msg.Id) + assert.EqualValues(t, list[0].Sender.Id, sender.Id) + assert.EqualValues(t, list[0].Sender.Name, sender.Name) + assert.EqualValues(t, list[0].Receiver.Id, receiver.Id) + assert.EqualValues(t, list[0].Receiver.Name, receiver.Name) + assert.EqualValues(t, list[0].Type.Id, msgtype.Id) + assert.EqualValues(t, list[0].Type.Name, msgtype.Name) +} + +func TestExtends4(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) + assert.NoError(t, err) + + var sender = MessageUser{Name: "sender"} + var msgtype = MessageType{Name: "type"} + _, err = testEngine.Insert(&sender, &msgtype) + assert.NoError(t, err) + + msg := Message{ + MessageBase: MessageBase{ + Id: msgtype.Id, + }, + Title: "test", + Content: "test", + Uid: sender.Id, + } + + session := testEngine.NewSession() + defer session.Close() + + // MSSQL denies inserting into an identity column unless it is declared as below + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + err = session.Begin() + assert.NoError(t, err) + _, err = session.Exec("SET IDENTITY_INSERT message ON") + assert.NoError(t, err) + } + _, err = session.Insert(&msg) + assert.NoError(t, err) + + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + err = session.Commit() + assert.NoError(t, err) + } + + var mapper = testEngine.GetTableMapper().Obj2Table + var quote = testEngine.Quote + userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) + typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) + msgTableName := quote(testEngine.TableName(mapper("Message"), true)) + + list := make([]MessageExtend4, 0) + err = session.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). + Join("LEFT", typeTableName, typeTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
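+ // Same shape of query as TestExtends3, but without aliases: each joined table maps onto one anonymously embedded struct of MessageExtend4.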
+ Find(&list) + assert.NoError(t, err) + assert.EqualValues(t, len(list), 1) + assert.EqualValues(t, list[0].Message.Id, msg.Id) + assert.EqualValues(t, list[0].MessageUser.Id, sender.Id) + assert.EqualValues(t, list[0].MessageUser.Name, sender.Name) + assert.EqualValues(t, list[0].MessageType.Id, msgtype.Id) + assert.EqualValues(t, list[0].MessageType.Name, msgtype.Name) +} + +type Size struct { + ID int64 `xorm:"int(4) 'id' pk autoincr"` + Width float32 `json:"width" xorm:"float 'Width'"` + Height float32 `json:"height" xorm:"float 'Height'"` +} + +type Book struct { + ID int64 `xorm:"int(4) 'id' pk autoincr"` + SizeOpen *Size `xorm:"extends('Open')"` + SizeClosed *Size `xorm:"extends('Closed')"` + Size *Size `xorm:"extends('')"` +} + +func TestExtends5(t *testing.T) { + assert.NoError(t, PrepareEngine()) + err := testEngine.DropTables(&Book{}, &Size{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Size{}, &Book{}) + assert.NoError(t, err) + + var sc = Size{Width: 0.2, Height: 0.4} + var so = Size{Width: 0.2, Height: 0.8} + var s = Size{Width: 0.15, Height: 1.5} + var bk1 = Book{ + SizeOpen: &so, + SizeClosed: &sc, + Size: &s, + } + var bk2 = Book{ + SizeOpen: &so, + } + var bk3 = Book{ + SizeClosed: &sc, + Size: &s, + } + var bk4 = Book{} + var bk5 = Book{Size: &s} + _, err = testEngine.Insert(&sc, &so, &s, &bk1, &bk2, &bk3, &bk4, &bk5) + if err != nil { + t.Fatal(err) + } + + var books = map[int64]Book{ + bk1.ID: bk1, + bk2.ID: bk2, + bk3.ID: bk3, + bk4.ID: bk4, + bk5.ID: bk5, + } + + session := testEngine.NewSession() + defer session.Close() + + var mapper = testEngine.GetTableMapper().Obj2Table + var quote = testEngine.Quote + bookTableName := quote(testEngine.TableName(mapper("Book"), true)) + sizeTableName := quote(testEngine.TableName(mapper("Size"), true)) + + list := make([]Book, 0) + err = session. + Select(fmt.Sprintf( + "%s.%s, sc.%s AS %s, sc.%s AS %s, s.%s, s.%s", + quote(bookTableName), + quote("id"), + quote("Width"), + quote("ClosedWidth"), + quote("Height"), + quote("ClosedHeight"), + quote("Width"), + quote("Height"), + )). + Table(bookTableName). + Join( + "LEFT", + sizeTableName+" AS `sc`", + bookTableName+".`SizeClosed`=sc.`id`", + ). + Join( + "LEFT", + sizeTableName+" AS `s`", + bookTableName+".`Size`=s.`id`", + ). 
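+ // The sc.Width AS ClosedWidth / sc.Height AS ClosedHeight aliases in the Select above line up with the column prefix declared by the extends('Closed') tag on Book.SizeClosed.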
+ Find(&list) + assert.NoError(t, err) + + for _, book := range list { + if ok := assert.Equal(t, books[book.ID].SizeClosed.Width, book.SizeClosed.Width); !ok { + t.Error("Not bounded size closed") + panic("Not bounded size closed") + } + + if ok := assert.Equal(t, books[book.ID].SizeClosed.Height, book.SizeClosed.Height); !ok { + t.Error("Not bounded size closed") + panic("Not bounded size closed") + } + + if (books[book.ID].Size == nil) != (book.Size == nil) { + t.Fatal("Not bounded size") + } + + if books[book.ID].Size != nil && book.Size != nil { + if ok := assert.Equal(t, books[book.ID].Size.Width, book.Size.Width); !ok { + t.Error("Not bounded size") + panic("Not bounded size") + } + + if ok := assert.Equal(t, books[book.ID].Size.Height, book.Size.Height); !ok { + t.Error("Not bounded size") + panic("Not bounded size") + } + } + } +} + +func TestCacheTag(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type CacheDomain struct { + Id int64 `xorm:"pk cache"` + Name string + } + + assert.NoError(t, testEngine.CreateTables(&CacheDomain{})) + assert.True(t, testEngine.GetCacher(testEngine.TableName(&CacheDomain{})) != nil) +} + +func TestNoCacheTag(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type NoCacheDomain struct { + Id int64 `xorm:"pk nocache"` + Name string + } + + assert.NoError(t, testEngine.CreateTables(&NoCacheDomain{})) + assert.True(t, testEngine.GetCacher(testEngine.TableName(&NoCacheDomain{})) == nil) +} + +type IDGonicMapper struct { + ID int64 +} + +func TestGonicMapperID(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + oldMapper := testEngine.GetColumnMapper() + testEngine.UnMapType(utils.ReflectValue(new(IDGonicMapper)).Type()) + testEngine.SetMapper(names.LintGonicMapper) + defer func() { + testEngine.UnMapType(utils.ReflectValue(new(IDGonicMapper)).Type()) + testEngine.SetMapper(oldMapper) + }() + + err := testEngine.CreateTables(new(IDGonicMapper)) + if err != nil { + t.Fatal(err) + } + + tables, err := testEngine.DBMetas() + if err != nil { + t.Fatal(err) + } + + for _, tb := range tables { + if tb.Name == "id_gonic_mapper" { + if len(tb.PKColumns()) != 1 || tb.PKColumns()[0].Name != "id" { + t.Fatal(tb) + } + return + } + } + + t.Fatal("not table id_gonic_mapper") +} + +type IDSameMapper struct { + ID int64 +} + +func TestSameMapperID(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + oldMapper := testEngine.GetColumnMapper() + testEngine.UnMapType(utils.ReflectValue(new(IDSameMapper)).Type()) + testEngine.SetMapper(names.SameMapper{}) + defer func() { + testEngine.UnMapType(utils.ReflectValue(new(IDSameMapper)).Type()) + testEngine.SetMapper(oldMapper) + }() + + err := testEngine.CreateTables(new(IDSameMapper)) + if err != nil { + t.Fatal(err) + } + + tables, err := testEngine.DBMetas() + if err != nil { + t.Fatal(err) + } + + for _, tb := range tables { + if tb.Name == "IDSameMapper" { + if len(tb.PKColumns()) != 1 || tb.PKColumns()[0].Name != "ID" { + t.Fatalf("tb %s tb.PKColumns() is %d not 1, tb.PKColumns()[0].Name is %s not ID", tb.Name, len(tb.PKColumns()), tb.PKColumns()[0].Name) + } + return + } + } + t.Fatal("not table IDSameMapper") +} + +type UserCU struct { + Id int64 + Name string + Created time.Time `xorm:"created"` + Updated time.Time `xorm:"updated"` +} + +func TestCreatedAndUpdated(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + u := new(UserCU) + err := testEngine.DropTables(u) + assert.NoError(t, err) + + err = testEngine.CreateTables(u) + assert.NoError(t, err) + + u.Name = "sss" + cnt, err := testEngine.Insert(u) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + u.Name = "xxx" +
cnt, err = testEngine.ID(u.Id).Update(u) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + u.Id = 0 + u.Created = time.Now().Add(-time.Hour * 24 * 365) + u.Updated = u.Created + cnt, err = testEngine.NoAutoTime().Insert(u) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +type StrangeName struct { + Id_t int64 `xorm:"pk autoincr"` + Name string +} + +func TestStrangeName(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(new(StrangeName)) + assert.NoError(t, err) + + err = testEngine.CreateTables(new(StrangeName)) + assert.NoError(t, err) + + _, err = testEngine.Insert(&StrangeName{Name: "sfsfdsfds"}) + assert.NoError(t, err) + + beans := make([]StrangeName, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) +} + +func TestCreatedUpdated(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type CreatedUpdated struct { + Id int64 + Name string + Value float64 `xorm:"numeric"` + Created time.Time `xorm:"created"` + Created2 time.Time `xorm:"created"` + Updated time.Time `xorm:"updated"` + } + + err := testEngine.Sync2(&CreatedUpdated{}) + assert.NoError(t, err) + + c := &CreatedUpdated{Name: "test"} + _, err = testEngine.Insert(c) + assert.NoError(t, err) + + c2 := new(CreatedUpdated) + has, err := testEngine.ID(c.Id).Get(c2) + assert.NoError(t, err) + + assert.True(t, has) + + c2.Value-- + _, err = testEngine.ID(c2.Id).Update(c2) + assert.NoError(t, err) +} + +func TestCreatedUpdatedInt64(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type CreatedUpdatedInt64 struct { + Id int64 + Name string + Value float64 `xorm:"numeric"` + Created int64 `xorm:"created"` + Created2 int64 `xorm:"created"` + Updated int64 `xorm:"updated"` + } + + assertSync(t, &CreatedUpdatedInt64{}) + + c := &CreatedUpdatedInt64{Name: "test"} + _, err := testEngine.Insert(c) + assert.NoError(t, err) + + c2 := new(CreatedUpdatedInt64) + has, err := testEngine.ID(c.Id).Get(c2) + assert.NoError(t, err) + assert.True(t, has) + + c2.Value-- + _, err = testEngine.ID(c2.Id).Update(c2) + assert.NoError(t, err) +} + +type Lowercase struct { + Id int64 + Name string + ended int64 `xorm:"-"` +} + +func TestLowerCase(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.Sync2(&Lowercase{}) + assert.NoError(t, err) + _, err = testEngine.Where("id > 0").Delete(&Lowercase{}) + assert.NoError(t, err) + + _, err = testEngine.Insert(&Lowercase{ended: 1}) + assert.NoError(t, err) + + ls := make([]Lowercase, 0) + err = testEngine.Find(&ls) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(ls)) +} + +func TestAutoIncrTag(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestAutoIncr1 struct { + Id int64 + } + + tb, err := testEngine.TableInfo(new(TestAutoIncr1)) + assert.NoError(t, err) + + cols := tb.Columns() + assert.EqualValues(t, 1, len(cols)) + assert.True(t, cols[0].IsAutoIncrement) + assert.True(t, cols[0].IsPrimaryKey) + assert.Equal(t, "id", cols[0].Name) + + type TestAutoIncr2 struct { + Id int64 `xorm:"id"` + } + + tb, err = testEngine.TableInfo(new(TestAutoIncr2)) + assert.NoError(t, err) + + cols = tb.Columns() + assert.EqualValues(t, 1, len(cols)) + assert.False(t, cols[0].IsAutoIncrement) + assert.False(t, cols[0].IsPrimaryKey) + assert.Equal(t, "id", cols[0].Name) + + type TestAutoIncr3 struct { + Id int64 `xorm:"'ID'"` + } + + tb, err = testEngine.TableInfo(new(TestAutoIncr3)) + assert.NoError(t, err) + + cols = tb.Columns() + assert.EqualValues(t, 1, len(cols)) + assert.False(t, cols[0].IsAutoIncrement) + 
assert.False(t, cols[0].IsPrimaryKey) + assert.Equal(t, "ID", cols[0].Name) + + type TestAutoIncr4 struct { + Id int64 `xorm:"pk"` + } + + tb, err = testEngine.TableInfo(new(TestAutoIncr4)) + assert.NoError(t, err) + + cols = tb.Columns() + assert.EqualValues(t, 1, len(cols)) + assert.False(t, cols[0].IsAutoIncrement) + assert.True(t, cols[0].IsPrimaryKey) + assert.Equal(t, "id", cols[0].Name) +} + +func TestTagComment(t *testing.T) { + assert.NoError(t, PrepareEngine()) + // FIXME: only support mysql + if testEngine.Dialect().URI().DBType != schemas.MYSQL { + return + } + + type TestComment1 struct { + Id int64 `xorm:"comment(主键)"` + } + + assert.NoError(t, testEngine.Sync2(new(TestComment1))) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, 1, len(tables[0].Columns())) + assert.EqualValues(t, "主键", tables[0].Columns()[0].Comment) + + assert.NoError(t, testEngine.DropTables(new(TestComment1))) + + type TestComment2 struct { + Id int64 `xorm:"comment('主键')"` + } + + assert.NoError(t, testEngine.Sync2(new(TestComment2))) + + tables, err = testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, 1, len(tables[0].Columns())) + assert.EqualValues(t, "主键", tables[0].Columns()[0].Comment) +} + +func TestTagDefault(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type DefaultStruct struct { + Id int64 + Name string + Age int `xorm:"default(10)"` + } + + assertSync(t, new(DefaultStruct)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("age") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.True(t, isDefaultExist) + assert.EqualValues(t, "10", defaultVal) + + cnt, err := testEngine.Omit("age").Insert(&DefaultStruct{ + Name: "test", + Age: 20, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var s DefaultStruct + has, err := testEngine.ID(1).Get(&s) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 10, s.Age) + assert.EqualValues(t, "test", s.Name) +} + +func TestTagDefault2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type DefaultStruct2 struct { + Id int64 + Name string + } + + assertSync(t, new(DefaultStruct2)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct2") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("name") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.False(t, isDefaultExist, fmt.Sprintf("default value is --%v--", defaultVal)) + assert.EqualValues(t, "", defaultVal) +} + +func TestTagDefault3(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type DefaultStruct3 struct { + Id int64 + Name string `xorm:"default('myname')"` + } + + assertSync(t, new(DefaultStruct3)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct3") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("name") + 
assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.True(t, isDefaultExist) + assert.EqualValues(t, "'myname'", defaultVal) +} + +func TestTagDefault4(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type DefaultStruct4 struct { + Id int64 + Created time.Time `xorm:"default(CURRENT_TIMESTAMP)"` + } + + assertSync(t, new(DefaultStruct4)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct4") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("created") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.True(t, isDefaultExist) + assert.True(t, "CURRENT_TIMESTAMP" == defaultVal || + "current_timestamp()" == defaultVal || // for cockroach + "now()" == defaultVal || + "getdate" == defaultVal, defaultVal) +} + +func TestTagDefault5(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type DefaultStruct5 struct { + Id int64 + Created time.Time `xorm:"default('2006-01-02 15:04:05')"` + } + + assertSync(t, new(DefaultStruct5)) + table, err := testEngine.TableInfo(new(DefaultStruct5)) + assert.NoError(t, err) + + createdCol := table.GetColumn("created") + assert.NotNil(t, createdCol) + assert.EqualValues(t, "'2006-01-02 15:04:05'", createdCol.Default) + assert.False(t, createdCol.DefaultIsEmpty) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct5") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("created") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.True(t, isDefaultExist) + assert.EqualValues(t, "'2006-01-02 15:04:05'", defaultVal) +} + +func TestTagDefault6(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type DefaultStruct6 struct { + Id int64 + IsMan bool `xorm:"default(true)"` + } + + assertSync(t, new(DefaultStruct6)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct6") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("is_man") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.True(t, isDefaultExist) + if defaultVal == "1" { + defaultVal = "true" + } else if defaultVal == "0" { + defaultVal = "false" + } + assert.EqualValues(t, "true", defaultVal) +} + +func TestTagsDirection(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type OnlyFromDBStruct struct { + Id int64 + Name string + Uuid string `xorm:"<- default '1'"` + } + + assertSync(t, new(OnlyFromDBStruct)) + + cnt, err := testEngine.Insert(&OnlyFromDBStruct{ + Name: "test", + Uuid: "2", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var s OnlyFromDBStruct + has, err := testEngine.ID(1).Get(&s) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "1", s.Uuid) + assert.EqualValues(t, "test", s.Name) + + cnt, err = testEngine.ID(1).Update(&OnlyFromDBStruct{ + Uuid: "3", + Name: "test1", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var s3 OnlyFromDBStruct + has, err = 
testEngine.ID(1).Get(&s3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "1", s3.Uuid) + assert.EqualValues(t, "test1", s3.Name) + + type OnlyToDBStruct struct { + Id int64 + Name string + Uuid string `xorm:"->"` + } + + assertSync(t, new(OnlyToDBStruct)) + + cnt, err = testEngine.Insert(&OnlyToDBStruct{ + Name: "test", + Uuid: "2", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var s2 OnlyToDBStruct + has, err = testEngine.ID(1).Get(&s2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "", s2.Uuid) + assert.EqualValues(t, "test", s2.Name) +} + +func TestTagTime(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TagUTCStruct struct { + Id int64 + Name string + Created time.Time `xorm:"created utc"` + } + + assertSync(t, new(TagUTCStruct)) + + assert.EqualValues(t, time.Local.String(), testEngine.GetTZLocation().String()) + + s := TagUTCStruct{ + Name: "utc", + } + cnt, err := testEngine.Insert(&s) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var u TagUTCStruct + has, err := testEngine.ID(1).Get(&u) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, s.Created.Format("2006-01-02 15:04:05"), u.Created.Format("2006-01-02 15:04:05")) + + var tm string + has, err = testEngine.Table("tag_u_t_c_struct").Cols("created").Get(&tm) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, s.Created.UTC().Format("2006-01-02 15:04:05"), + strings.Replace(strings.Replace(tm, "T", " ", -1), "Z", "", -1)) +} + +func TestTagAutoIncr(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TagAutoIncr struct { + Id int64 + Name string + } + + assertSync(t, new(TagAutoIncr)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, tableMapper.Obj2Table("TagAutoIncr"), tables[0].Name) + col := tables[0].GetColumn(colMapper.Obj2Table("Id")) + assert.NotNil(t, col) + assert.True(t, col.IsPrimaryKey) + assert.True(t, col.IsAutoIncrement) + + col2 := tables[0].GetColumn(colMapper.Obj2Table("Name")) + assert.NotNil(t, col2) + assert.False(t, col2.IsPrimaryKey) + assert.False(t, col2.IsAutoIncrement) +} + +func TestTagPrimarykey(t *testing.T) { + assert.NoError(t, PrepareEngine()) + type TagPrimaryKey struct { + Id int64 `xorm:"pk"` + Name string `xorm:"VARCHAR(20) pk"` + } + + assertSync(t, new(TagPrimaryKey)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, tableMapper.Obj2Table("TagPrimaryKey"), tables[0].Name) + col := tables[0].GetColumn(colMapper.Obj2Table("Id")) + assert.NotNil(t, col) + assert.True(t, col.IsPrimaryKey) + assert.False(t, col.IsAutoIncrement) + + col2 := tables[0].GetColumn(colMapper.Obj2Table("Name")) + assert.NotNil(t, col2) + assert.True(t, col2.IsPrimaryKey) + assert.False(t, col2.IsAutoIncrement) +} + +type VersionS struct { + Id int64 + Name string + Ver int `xorm:"version"` + Created time.Time `xorm:"created"` +} + +func TestVersion1(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(new(VersionS)) + assert.NoError(t, err) + + err = testEngine.CreateTables(new(VersionS)) + assert.NoError(t, err) + + ver := &VersionS{Name: "sfsfdsfds"} + _, err = testEngine.Insert(ver) + assert.NoError(t, err) + assert.EqualValues(t, ver.Ver, 1) + + newVer := new(VersionS) + has, err := testEngine.ID(ver.Id).Get(newVer) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 
newVer.Ver, 1) + + newVer.Name = "-------" + _, err = testEngine.ID(ver.Id).Update(newVer) + assert.NoError(t, err) + assert.EqualValues(t, newVer.Ver, 2) + + newVer = new(VersionS) + has, err = testEngine.ID(ver.Id).Get(newVer) + assert.NoError(t, err) + assert.EqualValues(t, newVer.Ver, 2) +} + +func TestVersion2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(new(VersionS)) + assert.NoError(t, err) + + err = testEngine.CreateTables(new(VersionS)) + assert.NoError(t, err) + + var vers = []VersionS{ + {Name: "sfsfdsfds"}, + {Name: "xxxxx"}, + } + _, err = testEngine.Insert(vers) + assert.NoError(t, err) + for _, v := range vers { + assert.EqualValues(t, v.Ver, 1) + } +} + +type VersionUintS struct { + Id int64 + Name string + Ver uint `xorm:"version"` + Created time.Time `xorm:"created"` +} + +func TestVersion3(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(new(VersionUintS)) + assert.NoError(t, err) + + err = testEngine.CreateTables(new(VersionUintS)) + assert.NoError(t, err) + + ver := &VersionUintS{Name: "sfsfdsfds"} + _, err = testEngine.Insert(ver) + assert.NoError(t, err) + assert.EqualValues(t, ver.Ver, 1) + + newVer := new(VersionUintS) + has, err := testEngine.ID(ver.Id).Get(newVer) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, newVer.Ver, 1) + + newVer.Name = "-------" + _, err = testEngine.ID(ver.Id).Update(newVer) + assert.NoError(t, err) + assert.EqualValues(t, newVer.Ver, 2) + + newVer = new(VersionUintS) + has, err = testEngine.ID(ver.Id).Get(newVer) + assert.NoError(t, err) + assert.EqualValues(t, newVer.Ver, 2) +} + +func TestVersion4(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(new(VersionUintS)) + assert.NoError(t, err) + + err = testEngine.CreateTables(new(VersionUintS)) + assert.NoError(t, err) + + var vers = []VersionUintS{ + {Name: "sfsfdsfds"}, + {Name: "xxxxx"}, + } + _, err = testEngine.Insert(vers) + assert.NoError(t, err) + for _, v := range vers { + assert.EqualValues(t, v.Ver, 1) + } +} + +func TestIndexes(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestIndexesStruct struct { + Id int64 + Name string `xorm:"index unique(s)"` + Email string `xorm:"index unique(s)"` + } + + assertSync(t, new(TestIndexesStruct)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, 3, len(tables[0].Columns())) + slice1 := []string{ + testEngine.GetColumnMapper().Obj2Table("Id"), + testEngine.GetColumnMapper().Obj2Table("Name"), + testEngine.GetColumnMapper().Obj2Table("Email"), + } + slice2 := []string{ + tables[0].Columns()[0].Name, + tables[0].Columns()[1].Name, + tables[0].Columns()[2].Name, + } + sort.Strings(slice1) + sort.Strings(slice2) + assert.EqualValues(t, slice1, slice2) + assert.EqualValues(t, 3, len(tables[0].Indexes)) +} diff --git a/xorm_test.go b/integrations/tests.go similarity index 55% rename from xorm_test.go rename to integrations/tests.go index 2089f0b..aa1731b 100644 --- a/xorm_test.go +++ b/integrations/tests.go @@ -1,28 +1,27 @@ -// Copyright 2018 The Xorm Authors. All rights reserved. +// Copyright 2017 The Xorm Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
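+ // tests.go holds the shared engine setup for the integrations package; the exported PrepareEngine and MainTest replace the old unexported prepareEngine and per-package TestMain.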
-package xorm +package integrations import ( "database/sql" "flag" "fmt" - "log" "os" "strings" "testing" - _ "github.com/denisenkom/go-mssqldb" - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" - "github.com/xormplus/core" - _ "github.com/ziutek/mymysql/godrv" + "github.com/xormplus/xorm" + "github.com/xormplus/xorm/caches" + "github.com/xormplus/xorm/dialects" + "github.com/xormplus/xorm/log" + "github.com/xormplus/xorm/names" + "github.com/xormplus/xorm/schemas" ) var ( - testEngine EngineInterface + testEngine xorm.EngineInterface dbType string connString string @@ -30,13 +29,15 @@ var ( showSQL = flag.Bool("show_sql", true, "show generated SQLs") ptrConnStr = flag.String("conn_str", "./test.db?cache=shared&mode=rwc", "test database connection string") mapType = flag.String("map_type", "snake", "indicate the name mapping") - cache = flag.Bool("cache", false, "if enable cache") + cacheFlag = flag.Bool("cache", false, "if enable cache") cluster = flag.Bool("cluster", false, "if this is a cluster") splitter = flag.String("splitter", ";", "the splitter on connstr for cluster") schema = flag.String("schema", "", "specify the schema") ignoreSelectUpdate = flag.Bool("ignore_select_update", false, "ignore select update if implementation difference, only for tidb") - tableMapper core.IMapper - colMapper core.IMapper + ingoreUpdateLimit = flag.Bool("ignore_update_limit", false, "ignore update limit if implementation difference, only for cockroach") + quotePolicyStr = flag.String("quote", "always", "quote could be always, none, reserved") + tableMapper names.Mapper + colMapper names.Mapper ) func createEngine(dbType, connStr string) error { @@ -44,26 +45,23 @@ func createEngine(dbType, connStr string) error { var err error if !*cluster { - // create databases if not exist - var db *sql.DB - var err error - if strings.ToLower(dbType) != core.MSSQL { - db, err = sql.Open(dbType, connStr) - } else { - db, err = sql.Open(dbType, strings.Replace(connStr, "xorm_test", "master", -1)) - } - - if err != nil { - return err - } - - switch strings.ToLower(dbType) { - case core.MSSQL: + switch schemas.DBType(strings.ToLower(dbType)) { + case schemas.MSSQL: + db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "master", -1)) + if err != nil { + return err + } if _, err = db.Exec("If(db_id(N'xorm_test') IS NULL) BEGIN CREATE DATABASE xorm_test; END;"); err != nil { return fmt.Errorf("db.Exec: %v", err) } - case core.POSTGRES: - rows, err := db.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = 'xorm_test'")) + db.Close() + *ignoreSelectUpdate = true + case schemas.POSTGRES: + db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "postgres", -1)) + if err != nil { + return err + } + rows, err := db.Query("SELECT 1 FROM pg_database WHERE datname = 'xorm_test'") if err != nil { return fmt.Errorf("db.Query: %v", err) } @@ -75,20 +73,37 @@ func createEngine(dbType, connStr string) error { } } if *schema != "" { + db.Close() + db, err = sql.Open(dbType, connStr) + if err != nil { + return err + } + defer db.Close() if _, err = db.Exec("CREATE SCHEMA IF NOT EXISTS " + *schema); err != nil { return fmt.Errorf("CREATE SCHEMA: %v", err) } } - case core.MYSQL: + db.Close() + *ignoreSelectUpdate = true + case schemas.MYSQL: + db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "mysql", -1)) + if err != nil { + return err + } if _, err = db.Exec("CREATE DATABASE IF NOT EXISTS xorm_test"); err != nil { return fmt.Errorf("db.Exec:
%v", err) } + db.Close() + default: + *ignoreSelectUpdate = true } - db.Close() - testEngine, err = NewEngine(dbType, connStr) + testEngine, err = xorm.NewEngine(dbType, connStr) } else { - testEngine, err = NewEngineGroup(dbType, strings.Split(connStr, *splitter)) + testEngine, err = xorm.NewEngineGroup(dbType, strings.Split(connStr, *splitter)) + if dbType != "mysql" && dbType != "mymysql" { + *ignoreSelectUpdate = true + } } if err != nil { return err @@ -98,22 +113,30 @@ func createEngine(dbType, connStr string) error { testEngine.SetSchema(*schema) } testEngine.ShowSQL(*showSQL) - testEngine.SetLogLevel(core.LOG_DEBUG) - if *cache { - cacher := NewLRUCacher(NewMemoryStore(), 100000) + testEngine.SetLogLevel(log.LOG_DEBUG) + if *cacheFlag { + cacher := caches.NewLRUCacher(caches.NewMemoryStore(), 100000) testEngine.SetDefaultCacher(cacher) } if len(*mapType) > 0 { switch *mapType { case "snake": - testEngine.SetMapper(core.SnakeMapper{}) + testEngine.SetMapper(names.SnakeMapper{}) case "same": - testEngine.SetMapper(core.SameMapper{}) + testEngine.SetMapper(names.SameMapper{}) case "gonic": - testEngine.SetMapper(core.LintGonicMapper) + testEngine.SetMapper(names.LintGonicMapper) } } + + if *quotePolicyStr == "none" { + testEngine.SetQuotePolicy(dialects.QuotePolicyNone) + } else if *quotePolicyStr == "reserved" { + testEngine.SetQuotePolicy(dialects.QuotePolicyReserved) + } else { + testEngine.SetQuotePolicy(dialects.QuotePolicyAlways) + } } tableMapper = testEngine.GetTableMapper() @@ -133,11 +156,11 @@ func createEngine(dbType, connStr string) error { return nil } -func prepareEngine() error { +func PrepareEngine() error { return createEngine(dbType, connString) } -func TestMain(m *testing.M) { +func MainTest(m *testing.M) { flag.Parse() dbType = *db @@ -149,7 +172,7 @@ func TestMain(m *testing.M) { } } else { if ptrConnStr == nil { - log.Fatal("you should indicate conn string") + fmt.Println("you should indicate conn string") return } connString = *ptrConnStr @@ -165,8 +188,9 @@ func TestMain(m *testing.M) { testEngine = nil fmt.Println("testing", dbType, connString) - if err := prepareEngine(); err != nil { - log.Fatal(err) + if err := PrepareEngine(); err != nil { + fmt.Println(err) + os.Exit(1) return } @@ -178,9 +202,3 @@ func TestMain(m *testing.M) { os.Exit(res) } - -func TestPing(t *testing.T) { - if err := testEngine.Ping(); err != nil { - t.Fatal(err) - } -} diff --git a/time_test.go b/integrations/time_test.go similarity index 90% rename from time_test.go rename to integrations/time_test.go index de04c46..0a8208c 100644 --- a/time_test.go +++ b/integrations/time_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-package xorm +package integrations import ( "fmt" @@ -10,11 +10,17 @@ import ( "testing" "time" + "github.com/xormplus/xorm/internal/utils" + "github.com/stretchr/testify/assert" ) +func formatTime(t time.Time) string { + return t.Format("2006-01-02 15:04:05") +} + func TestTimeUserTime(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type TimeUser struct { Id string @@ -44,7 +50,7 @@ func TestTimeUserTime(t *testing.T) { } func TestTimeUserTimeDiffLoc(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) testEngine.SetTZLocation(loc) @@ -80,7 +86,7 @@ func TestTimeUserTimeDiffLoc(t *testing.T) { } func TestTimeUserCreated(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserCreated struct { Id string @@ -109,7 +115,7 @@ func TestTimeUserCreated(t *testing.T) { } func TestTimeUserCreatedDiffLoc(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) testEngine.SetTZLocation(loc) @@ -144,7 +150,7 @@ func TestTimeUserCreatedDiffLoc(t *testing.T) { } func TestTimeUserUpdated(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserUpdated struct { Id string @@ -195,7 +201,7 @@ func TestTimeUserUpdated(t *testing.T) { } func TestTimeUserUpdatedDiffLoc(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) testEngine.SetTZLocation(loc) @@ -252,7 +258,7 @@ func TestTimeUserUpdatedDiffLoc(t *testing.T) { } func TestTimeUserDeleted(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserDeleted struct { Id string @@ -282,14 +288,15 @@ func TestTimeUserDeleted(t *testing.T) { assert.EqualValues(t, formatTime(user.CreatedAt), formatTime(user2.CreatedAt)) assert.EqualValues(t, user.UpdatedAt.Unix(), user2.UpdatedAt.Unix()) assert.EqualValues(t, formatTime(user.UpdatedAt), formatTime(user2.UpdatedAt)) - assert.True(t, isTimeZero(user2.DeletedAt)) + assert.True(t, utils.IsTimeZero(user2.DeletedAt)) fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt, user2.DeletedAt) + fmt.Println("user2 str", user2.CreatedAtStr, user2.UpdatedAtStr) var user3 UserDeleted cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - assert.True(t, !isTimeZero(user3.DeletedAt)) + assert.True(t, !utils.IsTimeZero(user3.DeletedAt)) var user4 UserDeleted has, err = testEngine.Unscoped().Get(&user4) @@ -301,7 +308,7 @@ func TestTimeUserDeleted(t *testing.T) { } func TestTimeUserDeletedDiffLoc(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) testEngine.SetTZLocation(loc) @@ -335,14 +342,14 @@ func TestTimeUserDeletedDiffLoc(t *testing.T) { assert.EqualValues(t, formatTime(user.CreatedAt), formatTime(user2.CreatedAt)) assert.EqualValues(t, user.UpdatedAt.Unix(), user2.UpdatedAt.Unix()) assert.EqualValues(t, formatTime(user.UpdatedAt), formatTime(user2.UpdatedAt)) - assert.True(t, isTimeZero(user2.DeletedAt)) + assert.True(t, utils.IsTimeZero(user2.DeletedAt)) fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt, user2.DeletedAt) var user3 UserDeleted2 cnt, err = testEngine.Where("id 
= ?", "lunny").Delete(&user3) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - assert.True(t, !isTimeZero(user3.DeletedAt)) + assert.True(t, !utils.IsTimeZero(user3.DeletedAt)) var user4 UserDeleted2 has, err = testEngine.Unscoped().Get(&user4) @@ -353,38 +360,38 @@ func TestTimeUserDeletedDiffLoc(t *testing.T) { fmt.Println("user3", user3.DeletedAt, user4.DeletedAt) } -type JsonDate time.Time +type JSONDate time.Time -func (j JsonDate) MarshalJSON() ([]byte, error) { +func (j JSONDate) MarshalJSON() ([]byte, error) { if time.Time(j).IsZero() { return []byte(`""`), nil } return []byte(`"` + time.Time(j).Format("2006-01-02 15:04:05") + `"`), nil } -func (j *JsonDate) UnmarshalJSON(value []byte) error { +func (j *JSONDate) UnmarshalJSON(value []byte) error { var v = strings.TrimSpace(strings.Trim(string(value), "\"")) t, err := time.ParseInLocation("2006-01-02 15:04:05", v, time.Local) if err != nil { return err } - *j = JsonDate(t) + *j = JSONDate(t) return nil } -func (j *JsonDate) Unix() int64 { +func (j *JSONDate) Unix() int64 { return (*time.Time)(j).Unix() } func TestCustomTimeUserDeleted(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserDeleted3 struct { Id string - CreatedAt JsonDate `xorm:"created"` - UpdatedAt JsonDate `xorm:"updated"` - DeletedAt JsonDate `xorm:"deleted"` + CreatedAt JSONDate `xorm:"created"` + UpdatedAt JSONDate `xorm:"updated"` + DeletedAt JSONDate `xorm:"deleted"` } assertSync(t, new(UserDeleted3)) @@ -406,14 +413,14 @@ func TestCustomTimeUserDeleted(t *testing.T) { assert.EqualValues(t, formatTime(time.Time(user.CreatedAt)), formatTime(time.Time(user2.CreatedAt))) assert.EqualValues(t, user.UpdatedAt.Unix(), user2.UpdatedAt.Unix()) assert.EqualValues(t, formatTime(time.Time(user.UpdatedAt)), formatTime(time.Time(user2.UpdatedAt))) - assert.True(t, isTimeZero(time.Time(user2.DeletedAt))) + assert.True(t, utils.IsTimeZero(time.Time(user2.DeletedAt))) fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt, user2.DeletedAt) var user3 UserDeleted3 cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - assert.True(t, !isTimeZero(time.Time(user3.DeletedAt))) + assert.True(t, !utils.IsTimeZero(time.Time(user3.DeletedAt))) var user4 UserDeleted3 has, err = testEngine.Unscoped().Get(&user4) @@ -425,7 +432,7 @@ func TestCustomTimeUserDeleted(t *testing.T) { } func TestCustomTimeUserDeletedDiffLoc(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) testEngine.SetTZLocation(loc) @@ -435,9 +442,9 @@ func TestCustomTimeUserDeletedDiffLoc(t *testing.T) { type UserDeleted4 struct { Id string - CreatedAt JsonDate `xorm:"created"` - UpdatedAt JsonDate `xorm:"updated"` - DeletedAt JsonDate `xorm:"deleted"` + CreatedAt JSONDate `xorm:"created"` + UpdatedAt JSONDate `xorm:"updated"` + DeletedAt JSONDate `xorm:"deleted"` } assertSync(t, new(UserDeleted4)) @@ -459,14 +466,14 @@ func TestCustomTimeUserDeletedDiffLoc(t *testing.T) { assert.EqualValues(t, formatTime(time.Time(user.CreatedAt)), formatTime(time.Time(user2.CreatedAt))) assert.EqualValues(t, user.UpdatedAt.Unix(), user2.UpdatedAt.Unix()) assert.EqualValues(t, formatTime(time.Time(user.UpdatedAt)), formatTime(time.Time(user2.UpdatedAt))) - assert.True(t, isTimeZero(time.Time(user2.DeletedAt))) + assert.True(t, utils.IsTimeZero(time.Time(user2.DeletedAt))) fmt.Println("user2", user2.CreatedAt, 
user2.UpdatedAt, user2.DeletedAt) var user3 UserDeleted4 cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - assert.True(t, !isTimeZero(time.Time(user3.DeletedAt))) + assert.True(t, !utils.IsTimeZero(time.Time(user3.DeletedAt))) var user4 UserDeleted4 has, err = testEngine.Unscoped().Get(&user4) @@ -478,7 +485,7 @@ func TestCustomTimeUserDeletedDiffLoc(t *testing.T) { } func TestDeletedInt64(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type DeletedInt64Struct struct { Id int64 diff --git a/types_null_test.go b/integrations/types_null_test.go similarity index 54% rename from types_null_test.go rename to integrations/types_null_test.go index 22fc102..98bd86b 100644 --- a/types_null_test.go +++ b/integrations/types_null_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "database/sql" @@ -22,7 +22,8 @@ type NullType struct { Age sql.NullInt64 Height sql.NullFloat64 IsMan sql.NullBool `xorm:"null"` - CustomStruct CustomStruct `xorm:"valchar(64) null"` + Nil driver.Valuer + CustomStruct CustomStruct `xorm:"varchar(64) null"` } type CustomStruct struct { @@ -57,90 +58,61 @@ func (m CustomStruct) Value() (driver.Value, error) { } func TestCreateNullStructTable(t *testing.T) { - assert.NoError(t, prepareEngine()) - + assert.NoError(t, PrepareEngine()) err := testEngine.CreateTables(new(NullType)) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } func TestDropNullStructTable(t *testing.T) { - assert.NoError(t, prepareEngine()) - + assert.NoError(t, PrepareEngine()) err := testEngine.DropTables(new(NullType)) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } func TestNullStructInsert(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(NullType)) - if true { - item := new(NullType) - _, err := testEngine.Insert(item) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(item) - if item.Id != 1 { - err = errors.New("insert error") - t.Error(err) - panic(err) - } + item1 := new(NullType) + _, err := testEngine.Insert(item1) + assert.NoError(t, err) + assert.EqualValues(t, 1, item1.Id) + + item := NullType{ + Name: sql.NullString{String: "haolei", Valid: true}, + Age: sql.NullInt64{Int64: 34, Valid: true}, + Height: sql.NullFloat64{Float64: 1.72, Valid: true}, + IsMan: sql.NullBool{Bool: true, Valid: true}, + Nil: nil, } + _, err = testEngine.Insert(&item) + assert.NoError(t, err) + assert.EqualValues(t, 2, item.Id) - if true { + items := []NullType{} + for i := 0; i < 5; i++ { item := NullType{ - Name: sql.NullString{"haolei", true}, - Age: sql.NullInt64{34, true}, - Height: sql.NullFloat64{1.72, true}, - IsMan: sql.NullBool{true, true}, - } - _, err := testEngine.Insert(&item) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(item) - if item.Id != 2 { - err = errors.New("insert error") - t.Error(err) - panic(err) + Name: sql.NullString{String: "haolei_" + fmt.Sprint(i+1), Valid: true}, + Age: sql.NullInt64{Int64: 30 + int64(i), Valid: true}, + Height: sql.NullFloat64{Float64: 1.5 + 1.1*float64(i), Valid: true}, + IsMan: sql.NullBool{Bool: true, Valid: true}, + CustomStruct: CustomStruct{i, i + 1, i + 2}, + Nil: nil, } + items = append(items, item) } - if true { - items := []NullType{} - - for i := 0; i < 5; i++ { - item := NullType{ - Name: 
sql.NullString{"haolei_" + fmt.Sprint(i+1), true}, - Age: sql.NullInt64{30 + int64(i), true}, - Height: sql.NullFloat64{1.5 + 1.1*float64(i), true}, - IsMan: sql.NullBool{true, true}, - CustomStruct: CustomStruct{i, i + 1, i + 2}, - } - - items = append(items, item) - } + _, err = testEngine.Insert(&items) + assert.NoError(t, err) - _, err := testEngine.Insert(&items) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(items) - } + items = make([]NullType, 0, 7) + err = testEngine.Find(&items) + assert.NoError(t, err) + assert.EqualValues(t, 7, len(items)) } func TestNullStructUpdate(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(NullType)) _, err := testEngine.Insert([]NullType{ @@ -173,69 +145,47 @@ func TestNullStructUpdate(t *testing.T) { if true { // 测试可插入NULL item := new(NullType) - item.Age = sql.NullInt64{23, true} - item.Height = sql.NullFloat64{0, false} // update to NULL + item.Age = sql.NullInt64{Int64: 23, Valid: true} + item.Height = sql.NullFloat64{Float64: 0, Valid: false} // update to NULL affected, err := testEngine.ID(2).Cols("age", "height", "is_man").Update(item) - if err != nil { - t.Error(err) - panic(err) - } - if affected != 1 { - err := errors.New("update failed") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, affected) } if true { // 测试In update item := new(NullType) - item.Age = sql.NullInt64{23, true} + item.Age = sql.NullInt64{Int64: 23, Valid: true} affected, err := testEngine.In("id", 3, 4).Cols("age", "height", "is_man").Update(item) - if err != nil { - t.Error(err) - panic(err) - } - if affected != 2 { - err := errors.New("update failed") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 2, affected) } if true { // 测试where item := new(NullType) - item.Name = sql.NullString{"nullname", true} - item.IsMan = sql.NullBool{true, true} - item.Age = sql.NullInt64{34, true} + item.Name = sql.NullString{String: "nullname", Valid: true} + item.IsMan = sql.NullBool{Bool: true, Valid: true} + item.Age = sql.NullInt64{Int64: 34, Valid: true} _, err := testEngine.Where("age > ?", 34).Update(item) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } if true { // 修改全部时,插入空值 item := &NullType{ - Name: sql.NullString{"winxxp", true}, - Age: sql.NullInt64{30, true}, - Height: sql.NullFloat64{1.72, true}, + Name: sql.NullString{String: "winxxp", Valid: true}, + Age: sql.NullInt64{Int64: 30, Valid: true}, + Height: sql.NullFloat64{Float64: 1.72, Valid: true}, // IsMan: sql.NullBool{true, true}, } _, err := testEngine.AllCols().ID(6).Update(item) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(item) + assert.NoError(t, err) } - } func TestNullStructFind(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(NullType)) _, err := testEngine.Insert([]NullType{ @@ -269,68 +219,38 @@ func TestNullStructFind(t *testing.T) { if true { item := new(NullType) has, err := testEngine.ID(1).Get(item) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - t.Error(errors.New("no find id 1")) - panic(err) - } - fmt.Println(item) - if item.Id != 1 || item.Name.Valid || item.Age.Valid || item.Height.Valid || - item.IsMan.Valid { - err = errors.New("insert error") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, item.Id, 1) + assert.False(t, item.Name.Valid) + assert.False(t, item.Age.Valid) + 
assert.False(t, item.Height.Valid) + assert.False(t, item.IsMan.Valid) } if true { item := new(NullType) item.Id = 2 - has, err := testEngine.Get(item) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - t.Error(errors.New("no find id 2")) - panic(err) - } - fmt.Println(item) + assert.NoError(t, err) + assert.True(t, has) } if true { item := make([]NullType, 0) - err := testEngine.ID(2).Find(&item) - if err != nil { - t.Error(err) - panic(err) - } - - fmt.Println(item) + assert.NoError(t, err) } if true { item := make([]NullType, 0) - err := testEngine.Asc("age").Find(&item) - if err != nil { - t.Error(err) - panic(err) - } - - for k, v := range item { - fmt.Println(k, v) - } + assert.NoError(t, err) } } func TestNullStructIterate(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(NullType)) if true { @@ -340,65 +260,45 @@ func TestNullStructIterate(t *testing.T) { fmt.Println(i, nultype) return nil }) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } } func TestNullStructCount(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(NullType)) if true { item := new(NullType) - total, err := testEngine.Where("age IS NOT NULL").Count(item) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(total) + _, err := testEngine.Where("age IS NOT NULL").Count(item) + assert.NoError(t, err) } } func TestNullStructRows(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(NullType)) item := new(NullType) rows, err := testEngine.Where("id > ?", 1).Rows(item) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) defer rows.Close() for rows.Next() { err = rows.Scan(item) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(item) + assert.NoError(t, err) } } func TestNullStructDelete(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(NullType)) item := new(NullType) _, err := testEngine.ID(1).Delete(item) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) _, err = testEngine.Where("id > ?", 1).Delete(item) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } diff --git a/types_test.go b/integrations/types_test.go similarity index 75% rename from types_test.go rename to integrations/types_test.go index 6510a01..1948558 100644 --- a/types_test.go +++ b/integrations/types_test.go @@ -2,19 +2,23 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
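A side note on the literal style in the rewrites above: positional composite literals such as sql.NullString{"haolei", true} trip go vet's composites check for structs imported from other packages, which is presumably why the tests now spell the fields out. A minimal sketch of the keyed form:

	// Keyed literals stay correct even if database/sql ever reorders its fields.
	var name = sql.NullString{String: "haolei", Valid: true}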
-package xorm +package integrations import ( "errors" "fmt" "testing" + "github.com/xormplus/xorm" + "github.com/xormplus/xorm/convert" + "github.com/xormplus/xorm/internal/json" + "github.com/xormplus/xorm/schemas" + "github.com/stretchr/testify/assert" - "github.com/xormplus/core" ) func TestArrayField(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type ArrayStruct struct { Id int64 @@ -77,7 +81,7 @@ func TestArrayField(t *testing.T) { } func TestGetBytes(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type Varbinary struct { Data []byte `xorm:"VARBINARY(250)"` @@ -116,40 +120,47 @@ type ConvConfig struct { } func (s *ConvConfig) FromDB(data []byte) error { - return DefaultJSONHandler.Unmarshal(data, s) + if data == nil { + s = nil + return nil + } + return json.DefaultJSONHandler.Unmarshal(data, s) } func (s *ConvConfig) ToDB() ([]byte, error) { - return DefaultJSONHandler.Marshal(s) + if s == nil { + return nil, nil + } + return json.DefaultJSONHandler.Marshal(s) } type SliceType []*ConvConfig func (s *SliceType) FromDB(data []byte) error { - return DefaultJSONHandler.Unmarshal(data, s) + return json.DefaultJSONHandler.Unmarshal(data, s) } func (s *SliceType) ToDB() ([]byte, error) { - return DefaultJSONHandler.Marshal(s) + return json.DefaultJSONHandler.Marshal(s) } type ConvStruct struct { Conv ConvString Conv2 *ConvString Cfg1 ConvConfig - Cfg2 *ConvConfig `xorm:"TEXT"` - Cfg3 core.Conversion `xorm:"BLOB"` + Cfg2 *ConvConfig `xorm:"TEXT"` + Cfg3 convert.Conversion `xorm:"BLOB"` Slice SliceType } -func (c *ConvStruct) BeforeSet(name string, cell Cell) { +func (c *ConvStruct) BeforeSet(name string, cell xorm.Cell) { if name == "cfg3" || name == "Cfg3" { c.Cfg3 = new(ConvConfig) } } func TestConversion(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) c := new(ConvStruct) assert.NoError(t, testEngine.DropTables(c)) @@ -181,6 +192,30 @@ func TestConversion(t *testing.T) { assert.EqualValues(t, 2, len(c1.Slice)) assert.EqualValues(t, *c.Slice[0], *c1.Slice[0]) assert.EqualValues(t, *c.Slice[1], *c1.Slice[1]) + + cnt, err := testEngine.Where("1=1").Delete(new(ConvStruct)) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + c.Cfg2 = nil + + _, err = testEngine.Insert(c) + assert.NoError(t, err) + + c2 := new(ConvStruct) + has, err = testEngine.Get(c2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "prefix---tttt", string(c2.Conv)) + assert.NotNil(t, c2.Conv2) + assert.EqualValues(t, "prefix---"+s, *c2.Conv2) + assert.EqualValues(t, c.Cfg1, c2.Cfg1) + assert.Nil(t, c2.Cfg2) + assert.NotNil(t, c2.Cfg3) + assert.EqualValues(t, *c.Cfg3.(*ConvConfig), *c2.Cfg3.(*ConvConfig)) + assert.EqualValues(t, 2, len(c2.Slice)) + assert.EqualValues(t, *c.Slice[0], *c2.Slice[0]) + assert.EqualValues(t, *c.Slice[1], *c2.Slice[1]) } type MyInt int @@ -209,7 +244,7 @@ type MyStruct struct { } func TestCustomType1(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) err := testEngine.DropTables(&MyStruct{}) assert.NoError(t, err) @@ -267,14 +302,14 @@ type Status struct { } var ( - _ core.Conversion = &Status{} - Registed Status = Status{"Registed", "white"} - Approved Status = Status{"Approved", "green"} - Removed Status = Status{"Removed", "red"} - Statuses map[string]Status = map[string]Status{ - Registed.Name: Registed, - Approved.Name: Approved, - Removed.Name: Removed, + _ convert.Conversion = &Status{} + 
Registered = Status{"Registered", "white"} + Approved = Status{"Approved", "green"} + Removed = Status{"Removed", "red"} + Statuses = map[string]Status{ + Registered.Name: Registered, + Approved.Name: Approved, + Removed.Name: Removed, } ) @@ -282,9 +317,8 @@ func (s *Status) FromDB(bytes []byte) error { if r, ok := Statuses[string(bytes)]; ok { *s = r return nil - } else { - return errors.New("no this data") } + return errors.New("no this data") } func (s *Status) ToDB() ([]byte, error) { @@ -298,7 +332,7 @@ type UserCus struct { } func TestCustomType2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) var uc UserCus err := testEngine.CreateTables(&uc) @@ -311,18 +345,18 @@ func TestCustomType2(t *testing.T) { session := testEngine.NewSession() defer session.Close() - if testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Begin() assert.NoError(t, err) _, err = session.Exec("set IDENTITY_INSERT " + tableName + " on") assert.NoError(t, err) } - cnt, err := session.Insert(&UserCus{1, "xlw", Registed}) + cnt, err := session.Insert(&UserCus{1, "xlw", Registered}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - if testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Commit() assert.NoError(t, err) } @@ -335,7 +369,7 @@ func TestCustomType2(t *testing.T) { fmt.Println(user) users := make([]UserCus, 0) - err = testEngine.Where("`"+testEngine.GetColumnMapper().Obj2Table("Status")+"` = ?", "Registed").Find(&users) + err = testEngine.Where("`"+testEngine.GetColumnMapper().Obj2Table("Status")+"` = ?", "Registered").Find(&users) assert.NoError(t, err) assert.EqualValues(t, 1, len(users)) diff --git a/interface.go b/interface.go index fb16c06..92cf4b1 100644 --- a/interface.go +++ b/interface.go @@ -10,7 +10,11 @@ import ( "reflect" "time" - "github.com/xormplus/core" + "github.com/xormplus/xorm/caches" + "github.com/xormplus/xorm/dialects" + "github.com/xormplus/xorm/log" + "github.com/xormplus/xorm/names" + "github.com/xormplus/xorm/schemas" ) // Interface defines the interface which Engine, EngineGroup and Session will implementate. 
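An aside on the Interface type above: since *Engine, *EngineGroup and *Session all satisfy it, query helpers can be written once and reused with any of the three. A minimal sketch, assuming a hypothetical User bean:

	// loadUsers accepts an Engine, an EngineGroup or an open Session.
	func loadUsers(db xorm.Interface) ([]User, error) {
		var users []User
		err := db.Find(&users) // Find is part of Interface, per the hunk above
		return users, err
	}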
@@ -28,7 +32,7 @@ type Interface interface { Delete(interface{}) (int64, error) Distinct(columns ...string) *Session DropIndexes(bean interface{}) error - Exec(sqlOrAgrs ...interface{}) (sql.Result, error) + Exec(sqlOrArgs ...interface{}) (sql.Result, error) Exist(bean ...interface{}) (bool, error) Find(interface{}, ...interface{}) error FindAndCount(interface{}, ...interface{}) (int64, error) @@ -50,13 +54,14 @@ type Interface interface { Omit(columns ...string) *Session OrderBy(order string) *Session Ping() error - QueryBytes(sqlOrAgrs ...interface{}) (resultsSlice []map[string][]byte, err error) + QueryBytes(sqlOrArgs ...interface{}) (resultsSlice []map[string][]byte, err error) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error) QueryString(sqlOrArgs ...interface{}) ([]map[string]string, error) QueryValue(sqlOrArgs ...interface{}) ([]map[string]Value, error) QueryResult(sqlOrArgs ...interface{}) (result *ResultValue) Rows(bean interface{}) (*Rows, error) SetExpr(string, interface{}) *Session + Select(string) *Session SQL(interface{}, ...interface{}) *Session Sum(bean interface{}, colName string) (float64, error) SumInt(bean interface{}, colName string) (int64, error) @@ -78,39 +83,41 @@ type EngineInterface interface { ClearCache(...interface{}) error Context(context.Context) *Session CreateTables(...interface{}) error - DBMetas() ([]*core.Table, error) - Dialect() core.Dialect + DBMetas() ([]*schemas.Table, error) + Dialect() dialects.Dialect + DriverName() string DropTables(...interface{}) error - DumpAllToFile(fp string, tp ...core.DbType) error - GetCacher(string) core.Cacher - GetColumnMapper() core.IMapper - GetDefaultCacher() core.Cacher - GetTableMapper() core.IMapper + DumpAllToFile(fp string, tp ...schemas.DBType) error + GetCacher(string) caches.Cacher + GetColumnMapper() names.Mapper + GetDefaultCacher() caches.Cacher + GetTableMapper() names.Mapper GetTZDatabase() *time.Location GetTZLocation() *time.Location - MapCacher(interface{}, core.Cacher) error + ImportFile(fp string) ([]sql.Result, error) + MapCacher(interface{}, caches.Cacher) error NewSession() *Session NoAutoTime() *Session Quote(string) string - SetCacher(string, core.Cacher) + SetCacher(string, caches.Cacher) SetConnMaxLifetime(time.Duration) - SetColumnMapper(core.IMapper) - SetDefaultCacher(core.Cacher) - SetLogger(logger core.ILogger) - SetLogLevel(core.LogLevel) - SetMapper(core.IMapper) + SetColumnMapper(names.Mapper) + SetDefaultCacher(caches.Cacher) + SetLogger(logger interface{}) + SetLogLevel(log.LogLevel) + SetMapper(names.Mapper) SetMaxOpenConns(int) SetMaxIdleConns(int) + SetQuotePolicy(dialects.QuotePolicy) SetSchema(string) - SetTableMapper(core.IMapper) + SetTableMapper(names.Mapper) SetTZDatabase(tz *time.Location) SetTZLocation(tz *time.Location) - ShowExecTime(...bool) ShowSQL(show ...bool) Sync(...interface{}) error Sync2(...interface{}) error StoreEngine(storeEngine string) *Session - TableInfo(bean interface{}) *Table + TableInfo(bean interface{}) (*schemas.Table, error) TableName(interface{}, ...bool) string UnMapType(reflect.Type) } diff --git a/internal/json/json.go b/internal/json/json.go new file mode 100644 index 0000000..c9a2eb4 --- /dev/null +++ b/internal/json/json.go @@ -0,0 +1,31 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package json + +import "encoding/json" + +// JSONInterface represents an interface to handle json data +type JSONInterface interface { + Marshal(v interface{}) ([]byte, error) + Unmarshal(data []byte, v interface{}) error +} + +var ( + // DefaultJSONHandler default json handler + DefaultJSONHandler JSONInterface = StdJSON{} +) + +// StdJSON implements JSONInterface via encoding/json +type StdJSON struct{} + +// Marshal implements JSONInterface +func (StdJSON) Marshal(v interface{}) ([]byte, error) { + return json.Marshal(v) +} + +// Unmarshal implements JSONInterface +func (StdJSON) Unmarshal(data []byte, v interface{}) error { + return json.Unmarshal(data, v) +} diff --git a/internal/statements/cache.go b/internal/statements/cache.go new file mode 100644 index 0000000..5b2b895 --- /dev/null +++ b/internal/statements/cache.go @@ -0,0 +1,79 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + "strings" + + "github.com/xormplus/xorm/internal/utils" + "github.com/xormplus/xorm/schemas" +) + +func (statement *Statement) ConvertIDSQL(sqlStr string) string { + if statement.RefTable != nil { + cols := statement.RefTable.PKColumns() + if len(cols) == 0 { + return "" + } + + colstrs := statement.joinColumns(cols, false) + sqls := utils.SplitNNoCase(sqlStr, " from ", 2) + if len(sqls) != 2 { + return "" + } + + var top string + pLimitN := statement.LimitN + if pLimitN != nil && statement.dialect.URI().DBType == schemas.MSSQL { + top = fmt.Sprintf("TOP %d ", *pLimitN) + } + + newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1]) + return newsql + } + return "" +} + +func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) { + if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { + return "", "" + } + + colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true) + sqls := utils.SplitNNoCase(sqlStr, "where", 2) + if len(sqls) != 2 { + if len(sqls) == 1 { + return sqls[0], fmt.Sprintf("SELECT %v FROM %v", + colstrs, statement.quote(statement.TableName())) + } + return "", "" + } + + var whereStr = sqls[1] + + // TODO: for postgres only, if any other database? + var paraStr string + if statement.dialect.URI().DBType == schemas.POSTGRES { + paraStr = "$" + } else if statement.dialect.URI().DBType == schemas.MSSQL { + paraStr = ":" + } + + if paraStr != "" { + if strings.Contains(sqls[1], paraStr) { + dollers := strings.Split(sqls[1], paraStr) + whereStr = dollers[0] + for i, c := range dollers[1:] { + ccs := strings.SplitN(c, " ", 2) + whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1]) + } + } + } + + return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v", + colstrs, statement.quote(statement.TableName()), + whereStr) +} diff --git a/internal/statements/column_map.go b/internal/statements/column_map.go new file mode 100644 index 0000000..86e3921 --- /dev/null +++ b/internal/statements/column_map.go @@ -0,0 +1,66 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
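For orientation, ConvertIDSQL and ConvertUpdateSQL in cache.go above rewrite a statement into the SELECT that fetches the primary keys of the affected rows, which the cache layer uses to track and invalidate entries. Roughly, for a hypothetical table user with primary key id:

	UPDATE user SET name = ? WHERE age > ?
	-- becomes, as the second return value of ConvertUpdateSQL:
	SELECT id FROM user WHERE age > ?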
+ +package statements + +import ( + "strings" + + "github.com/xormplus/xorm/schemas" +) + +type columnMap []string + +func (m columnMap) Contain(colName string) bool { + if len(m) == 0 { + return false + } + + n := len(colName) + for _, mk := range m { + if len(mk) != n { + continue + } + if strings.EqualFold(mk, colName) { + return true + } + } + + return false +} + +func (m columnMap) Len() int { + return len(m) +} + +func (m columnMap) IsEmpty() bool { + return len(m) == 0 +} + +func (m *columnMap) Add(colName string) bool { + if m.Contain(colName) { + return false + } + *m = append(*m, colName) + return true +} + +func getFlagForColumn(m map[string]bool, col *schemas.Column) (val bool, has bool) { + if len(m) == 0 { + return false, false + } + + n := len(col.Name) + + for mk := range m { + if len(mk) != n { + continue + } + if strings.EqualFold(mk, col.Name) { + return m[mk], true + } + } + + return false, false +} diff --git a/statement_exprparam.go b/internal/statements/expr_param.go similarity index 64% rename from statement_exprparam.go rename to internal/statements/expr_param.go index 1ad43be..67c6c69 100644 --- a/statement_exprparam.go +++ b/internal/statements/expr_param.go @@ -1,14 +1,15 @@ -// Copyright 2020 The Xorm Authors. All rights reserved. +// Copyright 2019 The Xorm Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package statements import ( "fmt" "strings" "github.com/xormplus/builder" + "github.com/xormplus/xorm/schemas" ) type ErrUnsupportedExprType struct { @@ -25,22 +26,22 @@ type exprParam struct { } type exprParams struct { - colNames []string - args []interface{} + ColNames []string + Args []interface{} } func (exprs *exprParams) Len() int { - return len(exprs.colNames) + return len(exprs.ColNames) } func (exprs *exprParams) addParam(colName string, arg interface{}) { - exprs.colNames = append(exprs.colNames, colName) - exprs.args = append(exprs.args, arg) + exprs.ColNames = append(exprs.ColNames, colName) + exprs.Args = append(exprs.Args, arg) } -func (exprs *exprParams) isColExist(colName string) bool { - for _, name := range exprs.colNames { - if strings.EqualFold(trimQuote(name), trimQuote(colName)) { +func (exprs *exprParams) IsColExist(colName string) bool { + for _, name := range exprs.ColNames { + if strings.EqualFold(schemas.CommonQuoter.Trim(name), schemas.CommonQuoter.Trim(colName)) { return true } } @@ -48,16 +49,16 @@ func (exprs *exprParams) isColExist(colName string) bool { } func (exprs *exprParams) getByName(colName string) (exprParam, bool) { - for i, name := range exprs.colNames { + for i, name := range exprs.ColNames { if strings.EqualFold(name, colName) { - return exprParam{name, exprs.args[i]}, true + return exprParam{name, exprs.Args[i]}, true } } return exprParam{}, false } -func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error { - for i, expr := range exprs.args { +func (exprs *exprParams) WriteArgs(w *builder.BytesWriter) error { + for i, expr := range exprs.Args { switch arg := expr.(type) { case *builder.Builder: if _, err := w.WriteString("("); err != nil { @@ -69,12 +70,20 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error { if _, err := w.WriteString(")"); err != nil { return err } - default: + case string: + if arg == "" { + arg = "''" + } if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil { return err } + default: + if _, err := w.WriteString("?"); err != nil { + return err + } + 
w.Append(arg) } - if i != len(exprs.args)-1 { + if i != len(exprs.Args)-1 { if _, err := w.WriteString(","); err != nil { return err } @@ -84,7 +93,7 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error { } func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { - for i, colName := range exprs.colNames { + for i, colName := range exprs.ColNames { if _, err := w.WriteString(colName); err != nil { return err } @@ -92,7 +101,7 @@ func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { return err } - switch arg := exprs.args[i].(type) { + switch arg := exprs.Args[i].(type) { case *builder.Builder: if _, err := w.WriteString("("); err != nil { return err @@ -104,10 +113,10 @@ func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { return err } default: - w.Append(exprs.args[i]) + w.Append(exprs.Args[i]) } - if i+1 != len(exprs.colNames) { + if i+1 != len(exprs.ColNames) { if _, err := w.WriteString(","); err != nil { return err } diff --git a/internal/statements/insert.go b/internal/statements/insert.go new file mode 100644 index 0000000..b7840e4 --- /dev/null +++ b/internal/statements/insert.go @@ -0,0 +1,207 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + "strings" + + "github.com/xormplus/builder" + "github.com/xormplus/xorm/schemas" +) + +func (statement *Statement) writeInsertOutput(buf *strings.Builder, table *schemas.Table) error { + if statement.dialect.URI().DBType == schemas.MSSQL && len(table.AutoIncrement) > 0 { + if _, err := buf.WriteString(" OUTPUT Inserted."); err != nil { + return err + } + if _, err := buf.WriteString(table.AutoIncrement); err != nil { + return err + } + } + return nil +} + +// GenInsertSQL generates insert beans SQL +func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) (string, []interface{}, error) { + var ( + buf = builder.NewWriter() + exprs = statement.ExprColumns + table = statement.RefTable + tableName = statement.TableName() + ) + + if _, err := buf.WriteString("INSERT INTO "); err != nil { + return "", nil, err + } + + if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil { + return "", nil, err + } + + if len(colNames) <= 0 { + if statement.dialect.URI().DBType == schemas.MYSQL { + if _, err := buf.WriteString(" VALUES ()"); err != nil { + return "", nil, err + } + } else { + if err := statement.writeInsertOutput(buf.Builder, table); err != nil { + return "", nil, err + } + if _, err := buf.WriteString(" DEFAULT VALUES"); err != nil { + return "", nil, err + } + } + } else { + if _, err := buf.WriteString(" ("); err != nil { + return "", nil, err + } + + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames...), ","); err != nil { + return "", nil, err + } + + if _, err := buf.WriteString(")"); err != nil { + return "", nil, err + } + if err := statement.writeInsertOutput(buf.Builder, table); err != nil { + return "", nil, err + } + + if statement.Conds().IsValid() { + if _, err := buf.WriteString(" SELECT "); err != nil { + return "", nil, err + } + + if err := statement.WriteArgs(buf, args); err != nil { + return "", nil, err + } + + if len(exprs.Args) > 0 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + } + if err := exprs.WriteArgs(buf); err != nil { + return "", nil, err + } + + if _, err := 
buf.WriteString(" FROM "); err != nil { + return "", nil, err + } + + if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil { + return "", nil, err + } + + if _, err := buf.WriteString(" WHERE "); err != nil { + return "", nil, err + } + + if err := statement.Conds().WriteTo(buf); err != nil { + return "", nil, err + } + } else { + if _, err := buf.WriteString(" VALUES ("); err != nil { + return "", nil, err + } + + if err := statement.WriteArgs(buf, args); err != nil { + return "", nil, err + } + + if len(exprs.Args) > 0 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + } + + if err := exprs.WriteArgs(buf); err != nil { + return "", nil, err + } + + if _, err := buf.WriteString(")"); err != nil { + return "", nil, err + } + } + } + + if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES { + if _, err := buf.WriteString(" RETURNING "); err != nil { + return "", nil, err + } + if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil { + return "", nil, err + } + } + + return buf.String(), buf.Args(), nil +} + +// GenInsertMapSQL generates insert map SQL +func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{}) (string, []interface{}, error) { + var ( + buf = builder.NewWriter() + exprs = statement.ExprColumns + tableName = statement.TableName() + ) + + if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s (", statement.quote(tableName))); err != nil { + return "", nil, err + } + + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames...), ","); err != nil { + return "", nil, err + } + + // if insert where + if statement.Conds().IsValid() { + if _, err := buf.WriteString(") SELECT "); err != nil { + return "", nil, err + } + + if err := statement.WriteArgs(buf, args); err != nil { + return "", nil, err + } + + if len(exprs.Args) > 0 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + if err := exprs.WriteArgs(buf); err != nil { + return "", nil, err + } + } + + if _, err := buf.WriteString(fmt.Sprintf(" FROM %s WHERE ", statement.quote(tableName))); err != nil { + return "", nil, err + } + + if err := statement.Conds().WriteTo(buf); err != nil { + return "", nil, err + } + } else { + if _, err := buf.WriteString(") VALUES ("); err != nil { + return "", nil, err + } + if err := statement.WriteArgs(buf, args); err != nil { + return "", nil, err + } + + if len(exprs.Args) > 0 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + if err := exprs.WriteArgs(buf); err != nil { + return "", nil, err + } + } + if _, err := buf.WriteString(")"); err != nil { + return "", nil, err + } + } + + return buf.String(), buf.Args(), nil +} diff --git a/internal/statements/pk.go b/internal/statements/pk.go new file mode 100644 index 0000000..7004597 --- /dev/null +++ b/internal/statements/pk.go @@ -0,0 +1,79 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + "reflect" + + "github.com/xormplus/builder" + "github.com/xormplus/xorm/schemas" +) + +var ( + ptrPkType = reflect.TypeOf(&schemas.PK{}) + pkType = reflect.TypeOf(schemas.PK{}) + stringType = reflect.TypeOf("") + intType = reflect.TypeOf(int64(0)) + uintType = reflect.TypeOf(uint64(0)) +) + +// ID generate "where id = ? 
" statement or for composite key "where key1 = ? and key2 = ?" +func (statement *Statement) ID(id interface{}) *Statement { + switch t := id.(type) { + case *schemas.PK: + statement.idParam = *t + case schemas.PK: + statement.idParam = t + case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + statement.idParam = schemas.PK{id} + default: + idValue := reflect.ValueOf(id) + idType := idValue.Type() + + switch idType.Kind() { + case reflect.String: + statement.idParam = schemas.PK{idValue.Convert(stringType).Interface()} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + statement.idParam = schemas.PK{idValue.Convert(intType).Interface()} + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + statement.idParam = schemas.PK{idValue.Convert(uintType).Interface()} + case reflect.Slice: + if idType.ConvertibleTo(pkType) { + statement.idParam = idValue.Convert(pkType).Interface().(schemas.PK) + } + case reflect.Ptr: + if idType.ConvertibleTo(ptrPkType) { + statement.idParam = idValue.Convert(ptrPkType).Elem().Interface().(schemas.PK) + } + } + } + + if statement.idParam == nil { + statement.LastError = fmt.Errorf("ID param %#v is not supported", id) + } + + return statement +} + +func (statement *Statement) ProcessIDParam() error { + if statement.idParam == nil || statement.RefTable == nil { + return nil + } + + if len(statement.RefTable.PrimaryKeys) != len(statement.idParam) { + fmt.Println("=====", statement.RefTable.PrimaryKeys, statement.idParam) + return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d", + len(statement.RefTable.PrimaryKeys), + len(statement.idParam), + ) + } + + for i, col := range statement.RefTable.PKColumns() { + var colName = statement.colName(col, statement.TableName()) + statement.cond = statement.cond.And(builder.Eq{colName: statement.idParam[i]}) + } + return nil +} diff --git a/internal/statements/query.go b/internal/statements/query.go new file mode 100644 index 0000000..c21ea66 --- /dev/null +++ b/internal/statements/query.go @@ -0,0 +1,534 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+
+package statements
+
+import (
+	"errors"
+	"fmt"
+	"reflect"
+	"regexp"
+	"strings"
+
+	"github.com/xormplus/builder"
+	"github.com/xormplus/xorm/core"
+	"github.com/xormplus/xorm/dialects"
+	"github.com/xormplus/xorm/internal/utils"
+	"github.com/xormplus/xorm/schemas"
+)
+
+func (statement *Statement) genSelectSql(dialect dialects.Dialect, rownumber string) string {
+	var sql = statement.RawSQL
+	var orderBys = statement.OrderStr
+	pLimitN := statement.LimitN
+
+	if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE {
+		if statement.Start > 0 {
+			if pLimitN != nil {
+				sql = fmt.Sprintf("%v LIMIT %v OFFSET %v", sql, *pLimitN, statement.Start)
+			} else {
+				sql = fmt.Sprintf("%v LIMIT 0 OFFSET %v", sql, statement.Start)
+			}
+		} else if pLimitN != nil {
+			sql = fmt.Sprintf("%v LIMIT %v", sql, *pLimitN)
+		}
+	} else if dialect.URI().DBType == schemas.ORACLE {
+		if pLimitN != nil {
+			sql = fmt.Sprintf("SELECT aat.* FROM (SELECT at.*,ROWNUM %v FROM (%v) at WHERE ROWNUM <= %d) aat WHERE %v > %d",
+				rownumber, sql, statement.Start+*pLimitN, rownumber, statement.Start)
+		}
+	} else {
+		keepSelect := false
+		var fullQuery string
+		if statement.Start > 0 {
+			fullQuery = fmt.Sprintf("SELECT sq.* FROM (SELECT ROW_NUMBER() OVER (ORDER BY %v) AS %v,", orderBys, rownumber)
+		} else if pLimitN != nil {
+			fullQuery = fmt.Sprintf("SELECT TOP %d", *pLimitN)
+		} else {
+			keepSelect = true
+		}
+
+		if !keepSelect {
+			reg := regexp.MustCompile(`(?i)^\s*SELECT\s*`)
+			if loc := reg.FindStringIndex(sql); loc != nil {
+				fullQuery = fmt.Sprintf("%v %v", fullQuery, sql[loc[1]:])
+			}
+		}
+
+		if statement.Start > 0 {
+			// T-SQL row numbers start at 1, unlike MySQL's 0-based OFFSET
+			if pLimitN != nil {
+				fullQuery = fmt.Sprintf("%v) AS sq WHERE %v BETWEEN %d AND %d", fullQuery, rownumber, statement.Start+1, statement.Start+*pLimitN)
+			} else {
+				fullQuery = fmt.Sprintf("%v) AS sq WHERE %v >= %d", fullQuery, rownumber, statement.Start+1)
+			}
+		} else {
+			fullQuery = fmt.Sprintf("%v ORDER BY %v", fullQuery, orderBys)
+		}
+
+		if keepSelect {
+			if len(orderBys) > 0 {
+				sql = fmt.Sprintf("%v ORDER BY %v", sql, orderBys)
+			}
+		} else {
+			sql = fullQuery
+		}
+	}
+
+	return sql
+}
+
+func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) {
+	if len(sqlOrArgs) > 0 {
+		return statement.ConvertSQLOrArgs(sqlOrArgs...)
+ } + + if statement.RawSQL != "" { + var dialect = statement.dialect + rownumber := "xorm" + utils.NewShortUUID().String() + sql := statement.genSelectSql(dialect, rownumber) + + params := statement.RawParams + i := len(params) + + // var result []map[string]interface{} + // var err error + if i == 1 { + vv := reflect.ValueOf(params[0]) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return sql, params, nil + } else { + sqlStr1, param, _ := core.MapToSlice(sql, params[0]) + return sqlStr1, param, nil + } + } else { + return sql, params, nil + } + // return session.statement.RawSQL, session.statement.RawParams, nil + } + + if len(statement.TableName()) <= 0 { + return "", nil, ErrTableNotFound + } + + var columnStr = statement.ColumnStr() + if len(statement.SelectStr) > 0 { + columnStr = statement.SelectStr + } else { + if statement.JoinStr == "" { + if columnStr == "" { + if statement.GroupByStr != "" { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } else { + columnStr = statement.genColumnStr() + } + } + } else { + if columnStr == "" { + if statement.GroupByStr != "" { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } else { + columnStr = "*" + } + } + } + if columnStr == "" { + columnStr = "*" + } + } + + if err := statement.ProcessIDParam(); err != nil { + return "", nil, err + } + + sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) + if err != nil { + return "", nil, err + } + args := append(statement.joinArgs, condArgs...) + + // for mssql and use limit + qs := strings.Count(sqlStr, "?") + if len(args)*2 == qs { + args = append(args, args...) + } + + return sqlStr, args, nil +} + +func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { + if statement.RawSQL != "" { + return statement.GenRawSQL(), statement.RawParams, nil + } + + statement.SetRefBean(bean) + + var sumStrs = make([]string, 0, len(columns)) + for _, colName := range columns { + if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") { + colName = statement.quote(colName) + } else { + colName = statement.ReplaceQuote(colName) + } + sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName)) + } + sumSelect := strings.Join(sumStrs, ", ") + + if err := statement.mergeConds(bean); err != nil { + return "", nil, err + } + + sqlStr, condArgs, err := statement.genSelectSQL(sumSelect, true, true) + if err != nil { + return "", nil, err + } + + return sqlStr, append(statement.joinArgs, condArgs...), nil +} + +func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, error) { + v := rValue(bean) + isStruct := v.Kind() == reflect.Struct + if isStruct { + statement.SetRefBean(bean) + } + + var columnStr = statement.ColumnStr() + if len(statement.SelectStr) > 0 { + columnStr = statement.SelectStr + } else { + // TODO: always generate column names, not use * even if join + if len(statement.JoinStr) == 0 { + if len(columnStr) == 0 { + if len(statement.GroupByStr) > 0 { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } else { + columnStr = statement.genColumnStr() + } + } + } else { + if len(columnStr) == 0 { + if len(statement.GroupByStr) > 0 { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } + } + } + } + + if len(columnStr) == 0 { + columnStr = "*" + } + + if isStruct { + if err := statement.mergeConds(bean); err != nil { + return "", nil, err + } + } else { + if err := statement.ProcessIDParam(); err != nil { + return "", nil, err 
+ } + } + + sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) + if err != nil { + return "", nil, err + } + + return sqlStr, append(statement.joinArgs, condArgs...), nil +} + +// GenCountSQL generates the SQL for counting +func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interface{}, error) { + if statement.RawSQL != "" { + return statement.GenRawSQL(), statement.RawParams, nil + } + + var condArgs []interface{} + var err error + if len(beans) > 0 { + statement.SetRefBean(beans[0]) + if err := statement.mergeConds(beans[0]); err != nil { + return "", nil, err + } + } + + var selectSQL = statement.SelectStr + if len(selectSQL) <= 0 { + if statement.IsDistinct { + selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr()) + } else if statement.ColumnStr() != "" { + selectSQL = fmt.Sprintf("count(%s)", statement.ColumnStr()) + } else { + selectSQL = "count(*)" + } + } + sqlStr, condArgs, err := statement.genSelectSQL(selectSQL, false, false) + if err != nil { + return "", nil, err + } + + return sqlStr, append(statement.joinArgs, condArgs...), nil +} + +func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) { + var ( + distinct string + dialect = statement.dialect + quote = statement.quote + fromStr = " FROM " + top, mssqlCondi, whereStr string + ) + if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { + distinct = "DISTINCT " + } + + condSQL, condArgs, err := statement.GenCondSQL(statement.cond) + if err != nil { + return "", nil, err + } + if len(condSQL) > 0 { + whereStr = " WHERE " + condSQL + } + + if dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") { + fromStr += statement.TableName() + } else { + fromStr += quote(statement.TableName()) + } + + if statement.TableAlias != "" { + if dialect.URI().DBType == schemas.ORACLE { + fromStr += " " + quote(statement.TableAlias) + } else { + fromStr += " AS " + quote(statement.TableAlias) + } + } + if statement.JoinStr != "" { + fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr) + } + + pLimitN := statement.LimitN + if dialect.URI().DBType == schemas.MSSQL { + if pLimitN != nil { + LimitNValue := *pLimitN + top = fmt.Sprintf("TOP %d ", LimitNValue) + } + if statement.Start > 0 { + var column string + if len(statement.RefTable.PKColumns()) == 0 { + for _, index := range statement.RefTable.Indexes { + if len(index.Cols) == 1 { + column = index.Cols[0] + break + } + } + if len(column) == 0 { + column = statement.RefTable.ColumnsSeq()[0] + } + } else { + column = statement.RefTable.PKColumns()[0].Name + } + if statement.needTableName() { + if len(statement.TableAlias) > 0 { + column = statement.TableAlias + "." + column + } else { + column = statement.TableName() + "." 
+ column
+				}
+			}
+
+			var orderStr string
+			if needOrderBy && len(statement.OrderStr) > 0 {
+				orderStr = " ORDER BY " + statement.OrderStr
+			}
+
+			var groupStr string
+			if len(statement.GroupByStr) > 0 {
+				groupStr = " GROUP BY " + statement.GroupByStr
+			}
+			mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
+				column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
+		}
+	}
+
+	var buf strings.Builder
+	fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
+	if len(mssqlCondi) > 0 {
+		if len(whereStr) > 0 {
+			fmt.Fprint(&buf, " AND ", mssqlCondi)
+		} else {
+			fmt.Fprint(&buf, " WHERE ", mssqlCondi)
+		}
+	}
+
+	if statement.GroupByStr != "" {
+		fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr)
+	}
+	if statement.HavingStr != "" {
+		fmt.Fprint(&buf, " ", statement.HavingStr)
+	}
+	if needOrderBy && statement.OrderStr != "" {
+		fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
+	}
+	if needLimit {
+		if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE {
+			if statement.Start > 0 {
+				if pLimitN != nil {
+					fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
+				} else {
+					fmt.Fprintf(&buf, " LIMIT 0 OFFSET %v", statement.Start)
+				}
+			} else if pLimitN != nil {
+				fmt.Fprint(&buf, " LIMIT ", *pLimitN)
+			}
+		} else if dialect.URI().DBType == schemas.ORACLE {
+			if pLimitN != nil {
+				oldString := buf.String()
+				buf.Reset()
+				rawColStr := columnStr
+				if rawColStr == "*" {
+					rawColStr = "at.*"
+				}
+				fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
+					columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start)
+			}
+		}
+	}
+	if statement.IsForUpdate {
+		return dialect.ForUpdateSQL(buf.String()), condArgs, nil
+	}
+
+	return buf.String(), condArgs, nil
+}
+
+func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) {
+	if statement.RawSQL != "" {
+		return statement.GenRawSQL(), statement.RawParams, nil
+	}
+
+	var sqlStr string
+	var args []interface{}
+	var joinStr string
+	var err error
+	if len(bean) == 0 {
+		tableName := statement.TableName()
+		if len(tableName) <= 0 {
+			return "", nil, ErrTableNotFound
+		}
+
+		tableName = statement.quote(tableName)
+		if len(statement.JoinStr) > 0 {
+			joinStr = statement.JoinStr
+		}
+
+		if statement.Conds().IsValid() {
+			condSQL, condArgs, err := statement.GenCondSQL(statement.Conds())
+			if err != nil {
+				return "", nil, err
+			}
+
+			if statement.dialect.URI().DBType == schemas.MSSQL {
+				sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL)
+			} else if statement.dialect.URI().DBType == schemas.ORACLE {
+				sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE (%s) AND ROWNUM=1", tableName, joinStr, condSQL)
+			} else {
+				sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL)
+			}
+			args = condArgs
+		} else {
+			if statement.dialect.URI().DBType == schemas.MSSQL {
+				sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr)
+			} else if statement.dialect.URI().DBType == schemas.ORACLE {
+				sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr)
+			} else {
+				sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr)
+			}
+			args = []interface{}{}
+		}
+	} else {
+		beanValue := reflect.ValueOf(bean[0])
+		if beanValue.Kind() != reflect.Ptr {
+			return "", nil, errors.New("needs a pointer")
+		}
+
+		if beanValue.Elem().Kind() == reflect.Struct
{ + if err := statement.SetRefBean(bean[0]); err != nil { + return "", nil, err + } + } + + if len(statement.TableName()) <= 0 { + return "", nil, ErrTableNotFound + } + statement.Limit(1) + sqlStr, args, err = statement.GenGetSQL(bean[0]) + if err != nil { + return "", nil, err + } + } + + return sqlStr, args, nil +} + +func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) { + if statement.RawSQL != "" { + return statement.GenRawSQL(), statement.RawParams, nil + } + + var sqlStr string + var args []interface{} + var err error + + if len(statement.TableName()) <= 0 { + return "", nil, ErrTableNotFound + } + + var columnStr = statement.ColumnStr() + if len(statement.SelectStr) > 0 { + columnStr = statement.SelectStr + } else { + if statement.JoinStr == "" { + if columnStr == "" { + if statement.GroupByStr != "" { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } else { + columnStr = statement.genColumnStr() + } + } + } else { + if columnStr == "" { + if statement.GroupByStr != "" { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } else { + columnStr = "*" + } + } + } + if columnStr == "" { + columnStr = "*" + } + } + + statement.cond = statement.cond.And(autoCond) + + sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) + if err != nil { + return "", nil, err + } + args = append(statement.joinArgs, condArgs...) + // for mssql and use limit + qs := strings.Count(sqlStr, "?") + if len(args)*2 == qs { + args = append(args, args...) + } + + return sqlStr, args, nil +} diff --git a/internal/statements/statement.go b/internal/statements/statement.go new file mode 100644 index 0000000..42cf223 --- /dev/null +++ b/internal/statements/statement.go @@ -0,0 +1,996 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
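For orientation, the MSSQL path of genSelectSQL above emulates LIMIT/OFFSET with TOP plus a NOT IN subquery on a key column, since SQL Server has no LIMIT clause. For a hypothetical Limit(10, 20) on table user with primary key id it produces roughly:

	SELECT TOP 10 * FROM user WHERE (id NOT IN (SELECT TOP 20 id FROM user))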
+ +package statements + +import ( + "database/sql/driver" + "errors" + "fmt" + "reflect" + "strings" + "time" + + "github.com/xormplus/builder" + "github.com/xormplus/xorm/contexts" + "github.com/xormplus/xorm/convert" + "github.com/xormplus/xorm/dialects" + "github.com/xormplus/xorm/internal/json" + "github.com/xormplus/xorm/internal/utils" + "github.com/xormplus/xorm/schemas" + "github.com/xormplus/xorm/tags" +) + +var ( + // ErrConditionType condition type unsupported + ErrConditionType = errors.New("Unsupported condition type") + // ErrUnSupportedSQLType parameter of SQL is not supported + ErrUnSupportedSQLType = errors.New("Unsupported sql type") + // ErrUnSupportedType unsupported error + ErrUnSupportedType = errors.New("Unsupported type error") + // ErrTableNotFound table not found error + ErrTableNotFound = errors.New("Table not found") +) + +// Statement save all the sql info for executing SQL +type Statement struct { + RefTable *schemas.Table + dialect dialects.Dialect + defaultTimeZone *time.Location + tagParser *tags.Parser + Start int + LimitN *int + idParam schemas.PK + OrderStr string + JoinStr string + joinArgs []interface{} + GroupByStr string + HavingStr string + SelectStr string + useAllCols bool + AltTableName string + tableName string + RawSQL string + RawParams []interface{} + UseCascade bool + UseAutoJoin bool + StoreEngine string + Charset string + UseCache bool + UseAutoTime bool + NoAutoCondition bool + IsDistinct bool + IsForUpdate bool + TableAlias string + allUseBool bool + CheckVersion bool + unscoped bool + ColumnMap columnMap + OmitColumnMap columnMap + MustColumnMap map[string]bool + NullableMap map[string]bool + IncrColumns exprParams + DecrColumns exprParams + ExprColumns exprParams + cond builder.Cond + BufferSize int + Context contexts.ContextCache + LastError error +} + +// NewStatement creates a new statement +func NewStatement(dialect dialects.Dialect, tagParser *tags.Parser, defaultTimeZone *time.Location) *Statement { + statement := &Statement{ + dialect: dialect, + tagParser: tagParser, + defaultTimeZone: defaultTimeZone, + } + statement.Reset() + return statement +} + +func (statement *Statement) SetTableName(tableName string) { + statement.tableName = tableName +} + +func (statement *Statement) omitStr() string { + return statement.dialect.Quoter().Join(statement.OmitColumnMap, " ,") +} + +// GenRawSQL generates correct raw sql +func (statement *Statement) GenRawSQL() string { + return statement.ReplaceQuote(statement.RawSQL) +} + +func (statement *Statement) GenCondSQL(condOrBuilder interface{}) (string, []interface{}, error) { + condSQL, condArgs, err := builder.ToSQL(condOrBuilder) + if err != nil { + return "", nil, err + } + return statement.ReplaceQuote(condSQL), condArgs, nil +} + +func (statement *Statement) ReplaceQuote(sql string) string { + if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL || + statement.dialect.URI().DBType == schemas.SQLITE { + return sql + } + return statement.dialect.Quoter().Replace(sql) +} + +func (statement *Statement) SetContextCache(ctxCache contexts.ContextCache) { + statement.Context = ctxCache +} + +// Init reset all the statement's fields +func (statement *Statement) Reset() { + statement.RefTable = nil + statement.Start = 0 + statement.LimitN = nil + statement.OrderStr = "" + statement.UseCascade = true + statement.JoinStr = "" + statement.joinArgs = make([]interface{}, 0) + statement.GroupByStr = "" + statement.HavingStr = "" + statement.ColumnMap = columnMap{} + statement.OmitColumnMap = 
columnMap{} + statement.AltTableName = "" + statement.tableName = "" + statement.idParam = nil + statement.RawSQL = "" + statement.RawParams = make([]interface{}, 0) + statement.UseCache = true + statement.UseAutoTime = true + statement.NoAutoCondition = false + statement.IsDistinct = false + statement.IsForUpdate = false + statement.TableAlias = "" + statement.SelectStr = "" + statement.allUseBool = false + statement.useAllCols = false + statement.MustColumnMap = make(map[string]bool) + statement.NullableMap = make(map[string]bool) + statement.CheckVersion = true + statement.unscoped = false + statement.IncrColumns = exprParams{} + statement.DecrColumns = exprParams{} + statement.ExprColumns = exprParams{} + statement.cond = builder.NewCond() + statement.BufferSize = 0 + statement.Context = nil + statement.LastError = nil +} + +// NoAutoCondition if you do not want convert bean's field as query condition, then use this function +func (statement *Statement) SetNoAutoCondition(no ...bool) *Statement { + statement.NoAutoCondition = true + if len(no) > 0 { + statement.NoAutoCondition = no[0] + } + return statement +} + +// Alias set the table alias +func (statement *Statement) Alias(alias string) *Statement { + statement.TableAlias = alias + return statement +} + +// SQL adds raw sql statement +func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement { + switch query.(type) { + case (*builder.Builder): + var err error + statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL() + if err != nil { + statement.LastError = err + } + case string: + statement.RawSQL = query.(string) + statement.RawParams = args + default: + statement.LastError = ErrUnSupportedSQLType + } + + return statement +} + +// Where add Where statement +func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement { + return statement.And(query, args...) +} + +func (statement *Statement) quote(s string) string { + return statement.dialect.Quoter().Quote(s) +} + +// And add Where & and statement +func (statement *Statement) And(query interface{}, args ...interface{}) *Statement { + switch query.(type) { + case string: + cond := builder.Expr(query.(string), args...) + statement.cond = statement.cond.And(cond) + case map[string]interface{}: + queryMap := query.(map[string]interface{}) + newMap := make(map[string]interface{}) + for k, v := range queryMap { + newMap[statement.quote(k)] = v + } + statement.cond = statement.cond.And(builder.Eq(newMap)) + case builder.Cond: + cond := query.(builder.Cond) + statement.cond = statement.cond.And(cond) + for _, v := range args { + if vv, ok := v.(builder.Cond); ok { + statement.cond = statement.cond.And(vv) + } + } + default: + statement.LastError = ErrConditionType + } + + return statement +} + +// Or add Where & Or statement +func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement { + switch query.(type) { + case string: + cond := builder.Expr(query.(string), args...) + statement.cond = statement.cond.Or(cond) + case map[string]interface{}: + cond := builder.Eq(query.(map[string]interface{})) + statement.cond = statement.cond.Or(cond) + case builder.Cond: + cond := query.(builder.Cond) + statement.cond = statement.cond.Or(cond) + for _, v := range args { + if vv, ok := v.(builder.Cond); ok { + statement.cond = statement.cond.Or(vv) + } + } + default: + // TODO: not support condition type + } + return statement +} + +// In generate "Where column IN (?) 
" statement +func (statement *Statement) In(column string, args ...interface{}) *Statement { + in := builder.In(statement.quote(column), args...) + statement.cond = statement.cond.And(in) + return statement +} + +// NotIn generate "Where column NOT IN (?) " statement +func (statement *Statement) NotIn(column string, args ...interface{}) *Statement { + notIn := builder.NotIn(statement.quote(column), args...) + statement.cond = statement.cond.And(notIn) + return statement +} + +func (statement *Statement) SetRefValue(v reflect.Value) error { + var err error + statement.RefTable, err = statement.tagParser.ParseWithCache(reflect.Indirect(v)) + if err != nil { + return err + } + statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), v, true) + return nil +} + +func rValue(bean interface{}) reflect.Value { + return reflect.Indirect(reflect.ValueOf(bean)) +} + +func (statement *Statement) SetRefBean(bean interface{}) error { + var err error + statement.RefTable, err = statement.tagParser.ParseWithCache(rValue(bean)) + if err != nil { + return err + } + statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), bean, true) + return nil +} + +func (statement *Statement) needTableName() bool { + return len(statement.JoinStr) > 0 +} + +func (statement *Statement) colName(col *schemas.Column, tableName string) string { + if statement.needTableName() { + var nm = tableName + if len(statement.TableAlias) > 0 { + nm = statement.TableAlias + } + return statement.quote(nm) + "." + statement.quote(col.Name) + } + return statement.quote(col.Name) +} + +// TableName return current tableName +func (statement *Statement) TableName() string { + if statement.AltTableName != "" { + return statement.AltTableName + } + + return statement.tableName +} + +// Incr Generate "Update ... Set column = column + arg" statement +func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { + if len(arg) > 0 { + statement.IncrColumns.addParam(column, arg[0]) + } else { + statement.IncrColumns.addParam(column, 1) + } + return statement +} + +// Decr Generate "Update ... Set column = column - arg" statement +func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { + if len(arg) > 0 { + statement.DecrColumns.addParam(column, arg[0]) + } else { + statement.DecrColumns.addParam(column, 1) + } + return statement +} + +// SetExpr Generate "Update ... Set column = {expression}" statement +func (statement *Statement) SetExpr(column string, expression interface{}) *Statement { + if e, ok := expression.(string); ok { + statement.ExprColumns.addParam(column, statement.dialect.Quoter().Replace(e)) + } else { + statement.ExprColumns.addParam(column, expression) + } + return statement +} + +// Distinct generates "DISTINCT col1, col2 " statement +func (statement *Statement) Distinct(columns ...string) *Statement { + statement.IsDistinct = true + statement.Cols(columns...) + return statement +} + +// ForUpdate generates "SELECT ... 
FOR UPDATE" statement +func (statement *Statement) ForUpdate() *Statement { + statement.IsForUpdate = true + return statement +} + +// Select replace select +func (statement *Statement) Select(str string) *Statement { + statement.SelectStr = statement.ReplaceQuote(str) + return statement +} + +func col2NewCols(columns ...string) []string { + newColumns := make([]string, 0, len(columns)) + for _, col := range columns { + col = strings.Replace(col, "`", "", -1) + col = strings.Replace(col, `"`, "", -1) + ccols := strings.Split(col, ",") + for _, c := range ccols { + newColumns = append(newColumns, strings.TrimSpace(c)) + } + } + return newColumns +} + +// Cols generate "col1, col2" statement +func (statement *Statement) Cols(columns ...string) *Statement { + cols := col2NewCols(columns...) + for _, nc := range cols { + statement.ColumnMap.Add(nc) + } + return statement +} + +func (statement *Statement) ColumnStr() string { + return statement.dialect.Quoter().Join(statement.ColumnMap, ", ") +} + +// AllCols update use only: update all columns +func (statement *Statement) AllCols() *Statement { + statement.useAllCols = true + return statement +} + +// MustCols update use only: must update columns +func (statement *Statement) MustCols(columns ...string) *Statement { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.MustColumnMap[strings.ToLower(nc)] = true + } + return statement +} + +// UseBool indicates that use bool fields as update contents and query contiditions +func (statement *Statement) UseBool(columns ...string) *Statement { + if len(columns) > 0 { + statement.MustCols(columns...) + } else { + statement.allUseBool = true + } + return statement +} + +// Omit do not use the columns +func (statement *Statement) Omit(columns ...string) { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.OmitColumnMap = append(statement.OmitColumnMap, nc) + } +} + +// Nullable Update use only: update columns to null when value is nullable and zero-value +func (statement *Statement) Nullable(columns ...string) { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.NullableMap[strings.ToLower(nc)] = true + } +} + +// Top generate LIMIT limit statement +func (statement *Statement) Top(limit int) *Statement { + statement.Limit(limit) + return statement +} + +// Limit generate LIMIT start, limit statement +func (statement *Statement) Limit(limit int, start ...int) *Statement { + statement.LimitN = &limit + if len(start) > 0 { + statement.Start = start[0] + } + return statement +} + +// OrderBy generate "Order By order" statement +func (statement *Statement) OrderBy(order string) *Statement { + if len(statement.OrderStr) > 0 { + statement.OrderStr += ", " + } + statement.OrderStr += statement.ReplaceQuote(order) + return statement +} + +// Desc generate `ORDER BY xx DESC` +func (statement *Statement) Desc(colNames ...string) *Statement { + var buf strings.Builder + if len(statement.OrderStr) > 0 { + fmt.Fprint(&buf, statement.OrderStr, ", ") + } + for i, col := range colNames { + if i > 0 { + fmt.Fprint(&buf, ", ") + } + statement.dialect.Quoter().QuoteTo(&buf, col) + fmt.Fprint(&buf, " DESC") + } + statement.OrderStr = buf.String() + return statement +} + +// Asc provide asc order by query condition, the input parameters are columns. 
+func (statement *Statement) Asc(colNames ...string) *Statement {
+ var buf strings.Builder
+ if len(statement.OrderStr) > 0 {
+ fmt.Fprint(&buf, statement.OrderStr, ", ")
+ }
+ for i, col := range colNames {
+ if i > 0 {
+ fmt.Fprint(&buf, ", ")
+ }
+ statement.dialect.Quoter().QuoteTo(&buf, col)
+ fmt.Fprint(&buf, " ASC")
+ }
+ statement.OrderStr = buf.String()
+ return statement
+}
+
+func (statement *Statement) Conds() builder.Cond {
+ return statement.cond
+}
+
+// SetTable temporarily sets the table name; the parameter can be a string or a pointer to a struct
+func (statement *Statement) SetTable(tableNameOrBean interface{}) error {
+ v := rValue(tableNameOrBean)
+ t := v.Type()
+ if t.Kind() == reflect.Struct {
+ var err error
+ statement.RefTable, err = statement.tagParser.ParseWithCache(v)
+ if err != nil {
+ return err
+ }
+ }
+
+ statement.AltTableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tableNameOrBean, true)
+ return nil
+}
+
+// Join adds a JOIN clause. The joinOP should be one of INNER, LEFT OUTER, CROSS etc. and will be prepended to JOIN
+func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
+ var buf strings.Builder
+ if len(statement.JoinStr) > 0 {
+ fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
+ } else {
+ fmt.Fprintf(&buf, "%v JOIN ", joinOP)
+ }
+
+ switch tp := tablename.(type) {
+ case builder.Builder:
+ subSQL, subQueryArgs, err := tp.ToSQL()
+ if err != nil {
+ statement.LastError = err
+ return statement
+ }
+
+ fields := strings.Split(tp.TableName(), ".")
+ aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
+ aliasName = schemas.CommonQuoter.Trim(aliasName)
+
+ fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), aliasName, statement.ReplaceQuote(condition))
+ statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
+ case *builder.Builder:
+ subSQL, subQueryArgs, err := tp.ToSQL()
+ if err != nil {
+ statement.LastError = err
+ return statement
+ }
+
+ fields := strings.Split(tp.TableName(), ".")
+ aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
+ aliasName = schemas.CommonQuoter.Trim(aliasName)
+
+ fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), aliasName, statement.ReplaceQuote(condition))
+ statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
+ default:
+ tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true)
+ if !utils.IsSubQuery(tbName) {
+ var buf strings.Builder
+ statement.dialect.Quoter().QuoteTo(&buf, tbName)
+ tbName = buf.String()
+ }
+ fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condition))
+ }
+
+ statement.JoinStr = buf.String()
+ statement.joinArgs = append(statement.joinArgs, args...)
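+ // Note: args bound by a builder sub-query above were appended to joinArgs
+ // before the caller-supplied args, matching the order in which their
+ // placeholders appear in the generated JOIN SQL.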
+ return statement
+}
+
+// tbNameNoSchema returns the table name without its schema
+func (statement *Statement) tbNameNoSchema(table *schemas.Table) string {
+ if len(statement.AltTableName) > 0 {
+ return statement.AltTableName
+ }
+
+ return table.Name
+}
+
+// GroupBy generates "Group By keys" statement
+func (statement *Statement) GroupBy(keys string) *Statement {
+ statement.GroupByStr = statement.ReplaceQuote(keys)
+ return statement
+}
+
+// Having generates "Having conditions" statement
+func (statement *Statement) Having(conditions string) *Statement {
+ statement.HavingStr = fmt.Sprintf("HAVING %v", statement.ReplaceQuote(conditions))
+ return statement
+}
+
+// SetUnscoped always disables the struct tag "deleted"
+func (statement *Statement) SetUnscoped() *Statement {
+ statement.unscoped = true
+ return statement
+}
+
+func (statement *Statement) GetUnscoped() bool {
+ return statement.unscoped
+}
+
+func (statement *Statement) genColumnStr() string {
+ if statement.RefTable == nil {
+ return ""
+ }
+
+ var buf strings.Builder
+ columns := statement.RefTable.Columns()
+
+ for _, col := range columns {
+ if statement.OmitColumnMap.Contain(col.Name) {
+ continue
+ }
+
+ if len(statement.ColumnMap) > 0 && !statement.ColumnMap.Contain(col.Name) {
+ continue
+ }
+
+ if col.MapType == schemas.ONLYTODB {
+ continue
+ }
+
+ if buf.Len() != 0 {
+ buf.WriteString(", ")
+ }
+
+ if statement.JoinStr != "" {
+ if statement.TableAlias != "" {
+ buf.WriteString(statement.TableAlias)
+ } else {
+ buf.WriteString(statement.TableName())
+ }
+
+ buf.WriteString(".")
+ }
+
+ statement.dialect.Quoter().QuoteTo(&buf, col.Name)
+ }
+
+ return buf.String()
+}
+
+func (statement *Statement) GenCreateTableSQL() []string {
+ statement.RefTable.StoreEngine = statement.StoreEngine
+ statement.RefTable.Charset = statement.Charset
+ s, _ := statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName())
+ return s
+}
+
+func (statement *Statement) GenIndexSQL() []string {
+ var sqls []string
+ tbName := statement.TableName()
+ for _, index := range statement.RefTable.Indexes {
+ if index.Type == schemas.IndexType {
+ sql := statement.dialect.CreateIndexSQL(tbName, index)
+ sqls = append(sqls, sql)
+ }
+ }
+ return sqls
+}
+
+func uniqueName(tableName, uqeName string) string {
+ return fmt.Sprintf("UQE_%v_%v", tableName, uqeName)
+}
+
+func (statement *Statement) GenUniqueSQL() []string {
+ var sqls []string
+ tbName := statement.TableName()
+ for _, index := range statement.RefTable.Indexes {
+ if index.Type == schemas.UniqueType {
+ sql := statement.dialect.CreateIndexSQL(tbName, index)
+ sqls = append(sqls, sql)
+ }
+ }
+ return sqls
+}
+
+func (statement *Statement) GenDelIndexSQL() []string {
+ var sqls []string
+ tbName := statement.TableName()
+ idx := strings.Index(tbName, ".")
+ if idx > -1 {
+ tbName = tbName[idx+1:]
+ }
+ for _, index := range statement.RefTable.Indexes {
+ sqls = append(sqls, statement.dialect.DropIndexSQL(tbName, index))
+ }
+ return sqls
+}
+
+func (statement *Statement) buildConds2(table *schemas.Table, bean interface{},
+ includeVersion bool, includeUpdated bool, includeNil bool,
+ includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
+ mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) {
+ var conds []builder.Cond
+ for _, col := range table.Columns() {
+ if !includeVersion && col.IsVersion {
+ continue
+ }
+ if !includeUpdated && col.IsUpdated {
+ continue
+ }
+ if !includeAutoIncr && col.IsAutoIncrement {
+ continue
+ }
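+
+ // MSSQL cannot compare text, blob or timestampz columns with '=', and JSON
+ // columns are not directly comparable either, so both are skipped when
+ // auto-generating conditions below.
+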
if statement.dialect.URI().DBType == schemas.MSSQL && (col.SQLType.Name == schemas.Text || + col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) { + continue + } + if col.SQLType.IsJson() { + continue + } + + var colName string + if addedTableName { + var nm = tableName + if len(aliasName) > 0 { + nm = aliasName + } + colName = statement.quote(nm) + "." + statement.quote(col.Name) + } else { + colName = statement.quote(col.Name) + } + + fieldValuePtr, err := col.ValueOf(bean) + if err != nil { + if !strings.Contains(err.Error(), "is not valid") { + //engine.logger.Warn(err) + } + continue + } + + if col.IsDeleted && !unscoped { // tag "deleted" is enabled + conds = append(conds, statement.CondDeleted(col)) + } + + fieldValue := *fieldValuePtr + if fieldValue.Interface() == nil { + continue + } + + fieldType := reflect.TypeOf(fieldValue.Interface()) + requiredField := useAllCols + + if b, ok := getFlagForColumn(mustColumnMap, col); ok { + if b { + requiredField = true + } else { + continue + } + } + + if fieldType.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + if includeNil { + conds = append(conds, builder.Eq{colName: nil}) + } + continue + } else if !fieldValue.IsValid() { + continue + } else { + // dereference ptr type to instance type + fieldValue = fieldValue.Elem() + fieldType = reflect.TypeOf(fieldValue.Interface()) + requiredField = true + } + } + + var val interface{} + switch fieldType.Kind() { + case reflect.Bool: + if allUseBool || requiredField { + val = fieldValue.Interface() + } else { + // if a bool in a struct, it will not be as a condition because it default is false, + // please use Where() instead + continue + } + case reflect.String: + if !requiredField && fieldValue.String() == "" { + continue + } + // for MyString, should convert to string or panic + if fieldType.String() != reflect.String.String() { + val = fieldValue.String() + } else { + val = fieldValue.Interface() + } + case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: + if !requiredField && fieldValue.Int() == 0 { + continue + } + val = fieldValue.Interface() + case reflect.Float32, reflect.Float64: + if !requiredField && fieldValue.Float() == 0.0 { + continue + } + val = fieldValue.Interface() + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + if !requiredField && fieldValue.Uint() == 0 { + continue + } + t := int64(fieldValue.Uint()) + val = reflect.ValueOf(&t).Interface() + case reflect.Struct: + if fieldType.ConvertibleTo(schemas.TimeType) { + t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { + continue + } + val = dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) + } else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok { + continue + } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { + val, _ = valNul.Value() + if val == nil && !requiredField { + continue + } + } else { + if col.SQLType.IsJson() { + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + val = bytes + } + } else { + table, err := statement.tagParser.ParseWithCache(fieldValue) + if err != nil { + val = fieldValue.Interface() + } else 
{ + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + // fix non-int pk issues + //if pkField.Int() != 0 { + if pkField.IsValid() && !utils.IsZero(pkField.Interface()) { + val = pkField.Interface() + } else { + continue + } + } else { + //TODO: how to handler? + return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) + } + } + } + } + case reflect.Array: + continue + case reflect.Slice, reflect.Map: + if fieldValue == reflect.Zero(fieldType) { + continue + } + if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { + continue + } + + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && + fieldType.Elem().Kind() == reflect.Uint8 { + if fieldValue.Len() > 0 { + val = fieldValue.Bytes() + } else { + continue + } + } else { + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + val = bytes + } + } else { + continue + } + default: + val = fieldValue.Interface() + } + + conds = append(conds, builder.Eq{colName: val}) + } + + return builder.And(conds...), nil +} + +func (statement *Statement) BuildConds(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { + return statement.buildConds2(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, + statement.unscoped, statement.MustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) +} + +func (statement *Statement) mergeConds(bean interface{}) error { + if !statement.NoAutoCondition && statement.RefTable != nil { + var addedTableName = (len(statement.JoinStr) > 0) + autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName) + if err != nil { + return err + } + statement.cond = statement.cond.And(autoCond) + } + + if err := statement.ProcessIDParam(); err != nil { + return err + } + return nil +} + +func (statement *Statement) GenConds(bean interface{}) (string, []interface{}, error) { + if err := statement.mergeConds(bean); err != nil { + return "", nil, err + } + + return statement.GenCondSQL(statement.cond) +} + +func (statement *Statement) quoteColumnStr(columnStr string) string { + columns := strings.Split(columnStr, ",") + return statement.dialect.Quoter().Join(columns, ",") +} + +func (statement *Statement) ConvertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { + sql, args, err := convertSQLOrArgs(sqlOrArgs...) 
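+ // sqlOrArgs holds either a raw SQL string followed by its arguments, or a
+ // single builder.Builder / *builder.Builder; any other leading value makes
+ // convertSQLOrArgs return ErrUnSupportedType.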
+ if err != nil { + return "", nil, err + } + return statement.ReplaceQuote(sql), args, nil +} + +func convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { + switch sqlOrArgs[0].(type) { + case string: + return sqlOrArgs[0].(string), sqlOrArgs[1:], nil + case *builder.Builder: + return sqlOrArgs[0].(*builder.Builder).ToSQL() + case builder.Builder: + bd := sqlOrArgs[0].(builder.Builder) + return bd.ToSQL() + } + + return "", nil, ErrUnSupportedType +} + +func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName bool) string { + var colnames = make([]string, len(cols)) + for i, col := range cols { + if includeTableName { + colnames[i] = statement.quote(statement.TableName()) + + "." + statement.quote(col.Name) + } else { + colnames[i] = statement.quote(col.Name) + } + } + return strings.Join(colnames, ", ") +} + +// CondDeleted returns the conditions whether a record is soft deleted. +func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { + var colName = col.Name + if statement.JoinStr != "" { + var prefix string + if statement.TableAlias != "" { + prefix = statement.TableAlias + } else { + prefix = statement.TableName() + } + colName = statement.quote(prefix) + "." + statement.quote(col.Name) + } + var cond = builder.NewCond() + if col.SQLType.IsNumeric() { + cond = builder.Eq{colName: 0} + } else { + // FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value. + if statement.dialect.URI().DBType != schemas.MSSQL { + cond = builder.Eq{colName: utils.ZeroTime1} + } + } + + if col.Nullable { + cond = cond.Or(builder.IsNull{colName}) + } + + return cond +} diff --git a/statement_args.go b/internal/statements/statement_args.go similarity index 63% rename from statement_args.go rename to internal/statements/statement_args.go index 870ec67..43ea628 100644 --- a/statement_args.go +++ b/internal/statements/statement_args.go @@ -1,8 +1,8 @@ -// Copyright 2020 The Xorm Authors. All rights reserved. +// Copyright 2019 The Xorm Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-package xorm +package statements import ( "fmt" @@ -11,7 +11,7 @@ import ( "time" "github.com/xormplus/builder" - "github.com/xormplus/core" + "github.com/xormplus/xorm/schemas" ) func quoteNeeded(a interface{}) bool { @@ -77,30 +77,8 @@ func convertArg(arg interface{}, convertFunc func(string) string) string { const insertSelectPlaceHolder = true -func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error { +func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) error { switch argv := arg.(type) { - case bool: - if statement.Engine.dialect.DBType() == core.MSSQL { - if argv { - if _, err := w.WriteString("1"); err != nil { - return err - } - } else { - if _, err := w.WriteString("0"); err != nil { - return err - } - } - } else { - if argv { - if _, err := w.WriteString("true"); err != nil { - return err - } - } else { - if _, err := w.WriteString("false"); err != nil { - return err - } - } - } case *builder.Builder: if _, err := w.WriteString("("); err != nil { return err @@ -116,10 +94,18 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er if err := w.WriteByte('?'); err != nil { return err } - w.Append(arg) + if v, ok := arg.(bool); ok && statement.dialect.URI().DBType == schemas.MSSQL { + if v { + w.Append(1) + } else { + w.Append(0) + } + } else { + w.Append(arg) + } } else { var convertFunc = convertStringSingleQuote - if statement.Engine.dialect.DBType() == core.MYSQL { + if statement.dialect.URI().DBType == schemas.MYSQL { convertFunc = convertString } if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil { @@ -130,9 +116,9 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er return nil } -func (statement *Statement) writeArgs(w *builder.BytesWriter, args []interface{}) error { +func (statement *Statement) WriteArgs(w *builder.BytesWriter, args []interface{}) error { for i, arg := range args { - if err := statement.writeArg(w, arg); err != nil { + if err := statement.WriteArg(w, arg); err != nil { return err } @@ -144,27 +130,3 @@ func (statement *Statement) writeArgs(w *builder.BytesWriter, args []interface{} } return nil } - -func writeStrings(w *builder.BytesWriter, cols []string, leftQuote, rightQuote string) error { - for i, colName := range cols { - if len(leftQuote) > 0 && colName[0] != '`' { - if _, err := w.WriteString(leftQuote); err != nil { - return err - } - } - if _, err := w.WriteString(colName); err != nil { - return err - } - if len(rightQuote) > 0 && colName[len(colName)-1] != '`' { - if _, err := w.WriteString(rightQuote); err != nil { - return err - } - } - if i+1 != len(cols) { - if _, err := w.WriteString(","); err != nil { - return err - } - } - } - return nil -} diff --git a/statement_test.go b/internal/statements/statement_test.go similarity index 56% rename from statement_test.go rename to internal/statements/statement_test.go index 9627dab..02eb238 100644 --- a/statement_test.go +++ b/internal/statements/statement_test.go @@ -2,17 +2,43 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-package xorm
+package statements

import (
 "reflect"
 "strings"
 "testing"
+ "time"

 "github.com/stretchr/testify/assert"
- "github.com/xormplus/core"
+ "github.com/xormplus/xorm/caches"
+ "github.com/xormplus/xorm/dialects"
+ "github.com/xormplus/xorm/names"
+ "github.com/xormplus/xorm/schemas"
+ "github.com/xormplus/xorm/tags"
+
+ _ "github.com/mattn/go-sqlite3"
+)
+
+var (
+ dialect dialects.Dialect
+ tagParser *tags.Parser
)

+func TestMain(m *testing.M) {
+ var err error
+ dialect, err = dialects.OpenDialect("sqlite3", "./test.db")
+ if err != nil {
+ panic("unknown dialect")
+ }
+
+ tagParser = tags.NewParser("xorm", dialect, names.SnakeMapper{}, names.SnakeMapper{}, caches.NewManager())
+ if tagParser == nil {
+ panic("tags parser is nil")
+ }
+ m.Run()
+}
+
var colStrTests = []struct {
 omitColumn string
 onlyToDBColumnNdx int
@@ -27,14 +53,9 @@ var colStrTests = []struct {
}

func TestColumnsStringGeneration(t *testing.T) {
- if dbType == "postgres" || dbType == "mssql" {
- return
- }
-
- var statement *Statement
-
 for ndx, testCase := range colStrTests {
- statement = createTestStatement()
+ statement, err := createTestStatement()
+ assert.NoError(t, err)

 if testCase.omitColumn != "" {
 statement.Omit(testCase.omitColumn)
@@ -42,7 +63,7 @@ func TestColumnsStringGeneration(t *testing.T) {
 columns := statement.RefTable.Columns()

 if testCase.onlyToDBColumnNdx >= 0 {
- columns[testCase.onlyToDBColumnNdx].MapType = core.ONLYTODB
+ columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB
 }

 actual := statement.genColumnStr()

@@ -51,34 +72,7 @@ func TestColumnsStringGeneration(t *testing.T) {
 t.Errorf("[test #%d] Unexpected columns string:\nwant:\t%s\nhave:\t%s", ndx, testCase.expected, actual)
 }
 if testCase.onlyToDBColumnNdx >= 0 {
- columns[testCase.onlyToDBColumnNdx].MapType = core.TWOSIDES
Column must be skipped - } - - b.StartTimer() - - for i := 0; i < b.N; i++ { - actual := statement.genColumnStr() - - if actual != testCase.expected { - b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual) + columns[testCase.onlyToDBColumnNdx].MapType = schemas.TWOSIDES } } } @@ -88,7 +82,7 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { b.StopTimer() mapCols := make(map[string]bool) - cols := []*core.Column{ + cols := []*schemas.Column{ {Name: `ID`}, {Name: `IsDeleted`}, {Name: `Caption`}, @@ -122,7 +116,7 @@ func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) { b.StopTimer() mapCols := make(map[string]bool) - cols := []*core.Column{ + cols := []*schemas.Column{ {Name: `ID`}, {Name: `IsDeleted`}, {Name: `Caption`}, @@ -163,86 +157,40 @@ func (TestType) TableName() string { return "TestTable" } -func createTestStatement() *Statement { - if engine, ok := testEngine.(*Engine); ok { - statement := &Statement{} - statement.Init() - statement.Engine = engine - statement.setRefValue(reflect.ValueOf(TestType{})) - - return statement - } else if eg, ok := testEngine.(*EngineGroup); ok { - statement := &Statement{} - statement.Init() - statement.Engine = eg.Engine - statement.setRefValue(reflect.ValueOf(TestType{})) - - return statement +func createTestStatement() (*Statement, error) { + statement := NewStatement(dialect, tagParser, time.Local) + if err := statement.SetRefValue(reflect.ValueOf(TestType{})); err != nil { + return nil, err } - return nil + return statement, nil } -func TestDistinctAndCols(t *testing.T) { - type DistinctAndCols struct { - Id int64 - Name string - } - - assert.NoError(t, prepareEngine()) - assertSync(t, new(DistinctAndCols)) +func BenchmarkColumnsStringGeneration(b *testing.B) { + b.StopTimer() - cnt, err := testEngine.Insert(&DistinctAndCols{ - Name: "test", - }) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) + statement, err := createTestStatement() + if err != nil { + panic(err) + } - var names []string - err = testEngine.Table("distinct_and_cols").Cols("name").Distinct("name").Find(&names) - assert.NoError(t, err) - assert.EqualValues(t, 1, len(names)) - assert.EqualValues(t, "test", names[0]) -} + testCase := colStrTests[0] -func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) { - type TestOnlyFromDBField struct { - Id int64 `xorm:"PK"` - OnlyFromDBField string `xorm:"<-"` - OnlyToDBField string `xorm:"->"` - IngoreField string `xorm:"-"` + if testCase.omitColumn != "" { + statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped } - assertGetRecord := func() *TestOnlyFromDBField { - var record TestOnlyFromDBField - has, err := testEngine.Where("id = ?", 1).Get(&record) - assert.NoError(t, err) - assert.EqualValues(t, true, has) - assert.EqualValues(t, "", record.OnlyFromDBField) - return &record - + if testCase.onlyToDBColumnNdx >= 0 { + columns := statement.RefTable.Columns() + columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! 
Column must be skipped } - assert.NoError(t, prepareEngine()) - assertSync(t, new(TestOnlyFromDBField)) - - _, err := testEngine.Insert(&TestOnlyFromDBField{ - Id: 1, - OnlyFromDBField: "a", - OnlyToDBField: "b", - IngoreField: "c", - }) - assert.NoError(t, err) - - record := assertGetRecord() - record.OnlyFromDBField = "test" - testEngine.Update(record) - assertGetRecord() -} -func TestCol2NewColsWithQuote(t *testing.T) { - cols := []string{"f1", "f2", "t3.f3"} + b.StartTimer() - statement := createTestStatement() + for i := 0; i < b.N; i++ { + actual := statement.genColumnStr() - quotedCols := statement.col2NewColsWithQuote(cols...) - assert.EqualValues(t, []string{statement.Engine.Quote("f1"), statement.Engine.Quote("f2"), statement.Engine.Quote("t3.f3")}, quotedCols) + if actual != testCase.expected { + b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual) + } + } } diff --git a/internal/statements/update.go b/internal/statements/update.go new file mode 100644 index 0000000..cacb65b --- /dev/null +++ b/internal/statements/update.go @@ -0,0 +1,295 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "database/sql/driver" + "errors" + "fmt" + "reflect" + "time" + + "github.com/xormplus/xorm/convert" + "github.com/xormplus/xorm/dialects" + "github.com/xormplus/xorm/internal/json" + "github.com/xormplus/xorm/internal/utils" + "github.com/xormplus/xorm/schemas" +) + +func (statement *Statement) ifAddColUpdate(col *schemas.Column, includeVersion, includeUpdated, includeNil, + includeAutoIncr, update bool) (bool, error) { + columnMap := statement.ColumnMap + omitColumnMap := statement.OmitColumnMap + unscoped := statement.unscoped + + if !includeVersion && col.IsVersion { + return false, nil + } + if col.IsCreated && !columnMap.Contain(col.Name) { + return false, nil + } + if !includeUpdated && col.IsUpdated { + return false, nil + } + if !includeAutoIncr && col.IsAutoIncrement { + return false, nil + } + if col.IsDeleted && !unscoped { + return false, nil + } + if omitColumnMap.Contain(col.Name) { + return false, nil + } + if len(columnMap) > 0 && !columnMap.Contain(col.Name) { + return false, nil + } + + if col.MapType == schemas.ONLYFROMDB { + return false, nil + } + + if statement.IncrColumns.IsColExist(col.Name) { + return false, nil + } else if statement.DecrColumns.IsColExist(col.Name) { + return false, nil + } else if statement.ExprColumns.IsColExist(col.Name) { + return false, nil + } + + return true, nil +} + +// BuildUpdates auto generating update columnes and values according a struct +func (statement *Statement) BuildUpdates(tableValue reflect.Value, + includeVersion, includeUpdated, includeNil, + includeAutoIncr, update bool) ([]string, []interface{}, error) { + table := statement.RefTable + allUseBool := statement.allUseBool + useAllCols := statement.useAllCols + mustColumnMap := statement.MustColumnMap + nullableMap := statement.NullableMap + + var colNames = make([]string, 0) + var args = make([]interface{}, 0) + + for _, col := range table.Columns() { + ok, err := statement.ifAddColUpdate(col, includeVersion, includeUpdated, includeNil, + includeAutoIncr, update) + if err != nil { + return nil, nil, err + } + if !ok { + continue + } + + fieldValuePtr, err := col.ValueOfV(&tableValue) + if err != nil { + return nil, nil, err + } + + fieldValue := *fieldValuePtr + fieldType := 
reflect.TypeOf(fieldValue.Interface()) + if fieldType == nil { + continue + } + + requiredField := useAllCols + includeNil := useAllCols + + if b, ok := getFlagForColumn(mustColumnMap, col); ok { + if b { + requiredField = true + } else { + continue + } + } + + // !evalphobia! set fieldValue as nil when column is nullable and zero-value + if b, ok := getFlagForColumn(nullableMap, col); ok { + if b && col.Nullable && utils.IsZero(fieldValue.Interface()) { + var nilValue *int + fieldValue = reflect.ValueOf(nilValue) + fieldType = reflect.TypeOf(fieldValue.Interface()) + includeNil = true + } + } + + var val interface{} + + if fieldValue.CanAddr() { + if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { + data, err := structConvert.ToDB() + if err != nil { + return nil, nil, err + } + + val = data + goto APPEND + } + } + + if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok { + data, err := structConvert.ToDB() + if err != nil { + return nil, nil, err + } + + val = data + goto APPEND + } + + if fieldType.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + if includeNil { + args = append(args, nil) + colNames = append(colNames, fmt.Sprintf("%v=?", statement.quote(col.Name))) + } + continue + } else if !fieldValue.IsValid() { + continue + } else { + // dereference ptr type to instance type + fieldValue = fieldValue.Elem() + fieldType = reflect.TypeOf(fieldValue.Interface()) + requiredField = true + } + } + + switch fieldType.Kind() { + case reflect.Bool: + if allUseBool || requiredField { + val = fieldValue.Interface() + } else { + // if a bool in a struct, it will not be as a condition because it default is false, + // please use Where() instead + continue + } + case reflect.String: + if !requiredField && fieldValue.String() == "" { + continue + } + // for MyString, should convert to string or panic + if fieldType.String() != reflect.String.String() { + val = fieldValue.String() + } else { + val = fieldValue.Interface() + } + case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: + if !requiredField && fieldValue.Int() == 0 { + continue + } + val = fieldValue.Interface() + case reflect.Float32, reflect.Float64: + if !requiredField && fieldValue.Float() == 0.0 { + continue + } + val = fieldValue.Interface() + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + if !requiredField && fieldValue.Uint() == 0 { + continue + } + t := int64(fieldValue.Uint()) + val = reflect.ValueOf(&t).Interface() + case reflect.Struct: + if fieldType.ConvertibleTo(schemas.TimeType) { + t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { + continue + } + val = dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) + } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok { + val, _ = nulType.Value() + if val == nil && !requiredField { + continue + } + } else { + if !col.SQLType.IsJson() { + table, err := statement.tagParser.ParseWithCache(fieldValue) + if err != nil { + val = fieldValue.Interface() + } else { + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + // fix non-int pk issues + if pkField.IsValid() && (!requiredField && !utils.IsZero(pkField.Interface())) { + val = pkField.Interface() + } else { + continue + } + } else { + return nil, nil, errors.New("Not supported multiple primary keys") + } + } + } else { + // Blank struct 
cannot be used as update data
+ if requiredField || !utils.IsStructZero(fieldValue) {
+ bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
+ if err != nil {
+ return nil, nil, fmt.Errorf("marshal %v failed", fieldValue.Interface())
+ }
+ if col.SQLType.IsText() {
+ val = string(bytes)
+ } else if col.SQLType.IsBlob() {
+ val = bytes
+ }
+ } else {
+ continue
+ }
+ }
+ case reflect.Array, reflect.Slice, reflect.Map:
+ if !requiredField {
+ if fieldValue == reflect.Zero(fieldType) {
+ continue
+ }
+ if fieldType.Kind() == reflect.Array {
+ if utils.IsArrayZero(fieldValue) {
+ continue
+ }
+ } else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
+ continue
+ }
+ }
+
+ if col.SQLType.IsText() {
+ bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
+ if err != nil {
+ return nil, nil, err
+ }
+ val = string(bytes)
+ } else if col.SQLType.IsBlob() {
+ var bytes []byte
+ var err error
+ if fieldType.Kind() == reflect.Slice &&
+ fieldType.Elem().Kind() == reflect.Uint8 {
+ if fieldValue.Len() > 0 {
+ val = fieldValue.Bytes()
+ } else {
+ continue
+ }
+ } else if fieldType.Kind() == reflect.Array &&
+ fieldType.Elem().Kind() == reflect.Uint8 {
+ val = fieldValue.Slice(0, 0).Interface()
+ } else {
+ bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface())
+ if err != nil {
+ return nil, nil, err
+ }
+ val = bytes
+ }
+ } else {
+ continue
+ }
+ default:
+ val = fieldValue.Interface()
+ }
+
+ APPEND:
+ args = append(args, val)
+ colNames = append(colNames, fmt.Sprintf("%v = ?", statement.quote(col.Name)))
+ }
+
+ return colNames, args, nil
+}
diff --git a/internal/statements/values.go b/internal/statements/values.go
new file mode 100644
index 0000000..2c19aa2
--- /dev/null
+++ b/internal/statements/values.go
@@ -0,0 +1,154 @@
+// Copyright 2017 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package statements
+
+import (
+ "database/sql"
+ "database/sql/driver"
+ "fmt"
+ "reflect"
+ "time"
+
+ "github.com/xormplus/xorm/convert"
+ "github.com/xormplus/xorm/dialects"
+ "github.com/xormplus/xorm/internal/json"
+ "github.com/xormplus/xorm/schemas"
+)
+
+var (
+ nullFloatType = reflect.TypeOf(sql.NullFloat64{})
+)
+
+// Value2Interface converts a field value of a struct to an interface for putting into the database
+func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue reflect.Value) (interface{}, error) {
+ if fieldValue.CanAddr() {
+ if fieldConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
+ data, err := fieldConvert.ToDB()
+ if err != nil {
+ return nil, err
+ }
+ if col.SQLType.IsBlob() {
+ return data, nil
+ }
+ return string(data), nil
+ }
+ }
+
+ if fieldConvert, ok := fieldValue.Interface().(convert.Conversion); ok {
+ data, err := fieldConvert.ToDB()
+ if err != nil {
+ return nil, err
+ }
+ if col.SQLType.IsBlob() {
+ return data, nil
+ }
+ if nil == data {
+ return nil, nil
+ }
+ return string(data), nil
+ }
+
+ fieldType := fieldValue.Type()
+ k := fieldType.Kind()
+ if k == reflect.Ptr {
+ if fieldValue.IsNil() {
+ return nil, nil
+ } else if !fieldValue.IsValid() {
+ return nil, nil
+ } else {
+ // !nashtsai!
dereference pointer type to instance type
+ fieldValue = fieldValue.Elem()
+ fieldType = fieldValue.Type()
+ k = fieldType.Kind()
+ }
+ }
+
+ switch k {
+ case reflect.Bool:
+ return fieldValue.Bool(), nil
+ case reflect.String:
+ return fieldValue.String(), nil
+ case reflect.Struct:
+ if fieldType.ConvertibleTo(schemas.TimeType) {
+ t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
+ tf := dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t)
+ return tf, nil
+ } else if fieldType.ConvertibleTo(nullFloatType) {
+ t := fieldValue.Convert(nullFloatType).Interface().(sql.NullFloat64)
+ if !t.Valid {
+ return nil, nil
+ }
+ return t.Float64, nil
+ }
+
+ if !col.SQLType.IsJson() {
+ // !! add support for structs implementing the driver.Valuer interface, such as sql.NullString
+ if v, ok := fieldValue.Interface().(driver.Valuer); ok {
+ return v.Value()
+ }
+
+ fieldTable, err := statement.tagParser.ParseWithCache(fieldValue)
+ if err != nil {
+ return nil, err
+ }
+ if len(fieldTable.PrimaryKeys) == 1 {
+ pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName)
+ return pkField.Interface(), nil
+ }
+ return nil, fmt.Errorf("no primary key for col %v", col.Name)
+ }
+
+ if col.SQLType.IsText() {
+ bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
+ if err != nil {
+ return nil, err
+ }
+ return string(bytes), nil
+ } else if col.SQLType.IsBlob() {
+ bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
+ if err != nil {
+ return nil, err
+ }
+ return bytes, nil
+ }
+ return nil, fmt.Errorf("Unsupported type %v", fieldValue.Type())
+ case reflect.Complex64, reflect.Complex128:
+ bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
+ if err != nil {
+ return nil, err
+ }
+ return string(bytes), nil
+ case reflect.Array, reflect.Slice, reflect.Map:
+ if !fieldValue.IsValid() {
+ return fieldValue.Interface(), nil
+ }
+
+ if col.SQLType.IsText() {
+ bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
+ if err != nil {
+ return nil, err
+ }
+ return string(bytes), nil
+ } else if col.SQLType.IsBlob() {
+ var bytes []byte
+ var err error
+ if (k == reflect.Slice) &&
+ (fieldValue.Type().Elem().Kind() == reflect.Uint8) {
+ bytes = fieldValue.Bytes()
+ } else {
+ bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface())
+ if err != nil {
+ return nil, err
+ }
+ }
+ return bytes, nil
+ }
+ return nil, ErrUnSupportedType
+ case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
+ return int64(fieldValue.Uint()), nil
+ default:
+ return fieldValue.Interface(), nil
+ }
+}
diff --git a/internal/utils/name.go b/internal/utils/name.go
new file mode 100644
index 0000000..f5fc3ff
--- /dev/null
+++ b/internal/utils/name.go
@@ -0,0 +1,13 @@
+// Copyright 2020 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package utils
+
+import (
+ "fmt"
+)
+
+func IndexName(tableName, idxName string) string {
+ return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
+}
diff --git a/internal/utils/reflect.go b/internal/utils/reflect.go
new file mode 100644
index 0000000..3dad6bf
--- /dev/null
+++ b/internal/utils/reflect.go
@@ -0,0 +1,13 @@
+// Copyright 2020 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package utils
+
+import (
+ "reflect"
+)
+
+func ReflectValue(bean interface{}) reflect.Value {
+ return reflect.Indirect(reflect.ValueOf(bean))
+}
diff --git a/internal/utils/slice.go b/internal/utils/slice.go
new file mode 100644
index 0000000..8968570
--- /dev/null
+++ b/internal/utils/slice.go
@@ -0,0 +1,22 @@
+// Copyright 2020 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package utils
+
+import "sort"
+
+// SliceEq returns true if the two slices contain the same elements regardless of order; note that both slices are sorted in place.
+func SliceEq(left, right []string) bool {
+ if len(left) != len(right) {
+ return false
+ }
+ sort.Sort(sort.StringSlice(left))
+ sort.Sort(sort.StringSlice(right))
+ for i := 0; i < len(left); i++ {
+ if left[i] != right[i] {
+ return false
+ }
+ }
+ return true
+}
diff --git a/internal/utils/sql.go b/internal/utils/sql.go
new file mode 100644
index 0000000..5e68c4a
--- /dev/null
+++ b/internal/utils/sql.go
@@ -0,0 +1,19 @@
+// Copyright 2020 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package utils
+
+import (
+ "strings"
+)
+
+func IsSubQuery(tbName string) bool {
+ const selStr = "select"
+ if len(tbName) <= len(selStr)+1 {
+ return false
+ }
+
+ return strings.EqualFold(tbName[:len(selStr)], selStr) ||
+ strings.EqualFold(tbName[:len(selStr)+1], "("+selStr)
+}
diff --git a/internal/utils/strings.go b/internal/utils/strings.go
new file mode 100644
index 0000000..b5dc37b
--- /dev/null
+++ b/internal/utils/strings.go
@@ -0,0 +1,30 @@
+// Copyright 2017 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package utils
+
+import (
+ "strings"
+)
+
+func IndexNoCase(s, sep string) int {
+ return strings.Index(strings.ToLower(s), strings.ToLower(sep))
+}
+
+func SplitNoCase(s, sep string) []string {
+ idx := IndexNoCase(s, sep)
+ if idx < 0 {
+ return []string{s}
+ }
+ return strings.Split(s, s[idx:idx+len(sep)])
+}
+
+func SplitNNoCase(s, sep string, n int) []string {
+ idx := IndexNoCase(s, sep)
+ if idx < 0 {
+ return []string{s}
+ }
+ return strings.SplitN(s, s[idx:idx+len(sep)], n)
+}
+
diff --git a/uuid.go b/internal/utils/uuid.go
similarity index 99%
rename from uuid.go
rename to internal/utils/uuid.go
index d52a73b..3d94187 100644
--- a/uuid.go
+++ b/internal/utils/uuid.go
@@ -1,4 +1,4 @@
-package xorm
+package utils

import (
 "bytes"
diff --git a/internal/utils/zero.go b/internal/utils/zero.go
new file mode 100644
index 0000000..8f033c6
--- /dev/null
+++ b/internal/utils/zero.go
@@ -0,0 +1,145 @@
+// Copyright 2020 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package utils
+
+import (
+ "reflect"
+ "time"
+)
+
+type Zeroable interface {
+ IsZero() bool
+}
+
+var nilTime *time.Time
+
+// IsZero returns true if k is nil or has a zero value
+func IsZero(k interface{}) bool {
+ if k == nil {
+ return true
+ }
+
+ switch k.(type) {
+ case int:
+ return k.(int) == 0
+ case int8:
+ return k.(int8) == 0
+ case int16:
+ return k.(int16) == 0
+ case int32:
+ return k.(int32) == 0
+ case int64:
+ return k.(int64) == 0
+ case uint:
+ return k.(uint) == 0
+ case uint8:
+ return k.(uint8) == 0
+ case uint16:
+ return k.(uint16) == 0
+ case uint32:
+ return k.(uint32) == 0
+ case uint64:
+ return k.(uint64) == 0
+ case float32:
+ return k.(float32) == 0
+ case float64:
+ return k.(float64) == 0
+ case bool:
+ return k.(bool) == false
+ case string:
+ return k.(string) == ""
+ case *time.Time:
+ return k.(*time.Time) == nilTime || IsTimeZero(*k.(*time.Time))
+ case time.Time:
+ return IsTimeZero(k.(time.Time))
+ case Zeroable:
+ return k.(Zeroable) == nil || k.(Zeroable).IsZero()
+ case reflect.Value: // for Go versions below 1.13, where reflect.Value has no IsZero method
+ return IsValueZero(k.(reflect.Value))
+ }
+
+ return IsValueZero(reflect.ValueOf(k))
+}
+
+var zeroType = reflect.TypeOf((*Zeroable)(nil)).Elem()
+
+func IsValueZero(v reflect.Value) bool {
+ switch v.Kind() {
+ case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Slice:
+ return v.IsNil()
+ case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
+ return v.Int() == 0
+ case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
+ return v.Uint() == 0
+ case reflect.String:
+ return v.Len() == 0
+ case reflect.Ptr:
+ if v.IsNil() {
+ return true
+ }
+ return IsValueZero(v.Elem())
+ case reflect.Struct:
+ return IsStructZero(v)
+ case reflect.Array:
+ return IsArrayZero(v)
+ }
+ return false
+}
+
+func IsStructZero(v reflect.Value) bool {
+ if !v.IsValid() || v.NumField() == 0 {
+ return true
+ }
+
+ if v.Type().Implements(zeroType) {
+ f := v.MethodByName("IsZero")
+ if f.IsValid() {
+ res := f.Call(nil)
+ return len(res) == 1 && res[0].Bool()
+ }
+ }
+
+ for i := 0; i < v.NumField(); i++ {
+ field := v.Field(i)
+ switch field.Kind() {
+ case reflect.Ptr:
+ field = field.Elem()
+ fallthrough
+ case reflect.Struct:
+ if !IsStructZero(field) {
+ return false
+ }
+ default:
+ if field.CanInterface() && !IsZero(field.Interface()) {
+ return false
+ }
+ }
+ }
+ return true
+}
+
+func IsArrayZero(v reflect.Value) bool {
+ if !v.IsValid() || v.Len() == 0 {
+ return true
+ }
+
+ for i := 0; i < v.Len(); i++ {
+ if !IsZero(v.Index(i).Interface()) {
+ return false
+ }
+ }
+
+ return true
+}
+
+const (
+ ZeroTime0 = "0000-00-00 00:00:00"
+ ZeroTime1 = "0001-01-01 00:00:00"
+)
+
+func IsTimeZero(t time.Time) bool {
+ return t.IsZero() || t.Format("2006-01-02 15:04:05") == ZeroTime0 ||
+ t.Format("2006-01-02 15:04:05") == ZeroTime1
+}
diff --git a/internal/utils/zero_test.go b/internal/utils/zero_test.go
new file mode 100644
index 0000000..a5f4912
--- /dev/null
+++ b/internal/utils/zero_test.go
@@ -0,0 +1,73 @@
+// Copyright 2020 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+ +package utils + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type MyInt int +type ZeroStruct struct{} + +func TestZero(t *testing.T) { + var zeroValues = []interface{}{ + int8(0), + int16(0), + int(0), + int32(0), + int64(0), + uint8(0), + uint16(0), + uint(0), + uint32(0), + uint64(0), + MyInt(0), + reflect.ValueOf(0), + nil, + time.Time{}, + &time.Time{}, + nilTime, + ZeroStruct{}, + &ZeroStruct{}, + } + + for _, v := range zeroValues { + t.Run(fmt.Sprintf("%#v", v), func(t *testing.T) { + assert.True(t, IsZero(v)) + }) + } +} + +func TestIsValueZero(t *testing.T) { + var zeroReflectValues = []reflect.Value{ + reflect.ValueOf(int8(0)), + reflect.ValueOf(int16(0)), + reflect.ValueOf(int(0)), + reflect.ValueOf(int32(0)), + reflect.ValueOf(int64(0)), + reflect.ValueOf(uint8(0)), + reflect.ValueOf(uint16(0)), + reflect.ValueOf(uint(0)), + reflect.ValueOf(uint32(0)), + reflect.ValueOf(uint64(0)), + reflect.ValueOf(MyInt(0)), + reflect.ValueOf(time.Time{}), + reflect.ValueOf(&time.Time{}), + reflect.ValueOf(nilTime), + reflect.ValueOf(ZeroStruct{}), + reflect.ValueOf(&ZeroStruct{}), + } + + for _, v := range zeroReflectValues { + t.Run(fmt.Sprintf("%#v", v), func(t *testing.T) { + assert.True(t, IsValueZero(v)) + }) + } +} diff --git a/logger.go b/log/logger.go similarity index 64% rename from logger.go rename to log/logger.go index 0b524cf..eeb6369 100644 --- a/logger.go +++ b/log/logger.go @@ -2,26 +2,56 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package log import ( "fmt" "io" "log" +) + +// LogLevel defines a log level +type LogLevel int - "github.com/xormplus/core" +// enumerate all LogLevels +const ( + // !nashtsai! 
following levels also match syslog.Priority values
+ LOG_DEBUG LogLevel = iota
+ LOG_INFO
+ LOG_WARNING
+ LOG_ERR
+ LOG_OFF
+ LOG_UNKNOWN
)

// default log options
const (
 DEFAULT_LOG_PREFIX = "[xorm]"
 DEFAULT_LOG_FLAG = log.Ldate | log.Lmicroseconds
- DEFAULT_LOG_LEVEL = core.LOG_DEBUG
+ DEFAULT_LOG_LEVEL = LOG_DEBUG
)

-var _ core.ILogger = DiscardLogger{}
+// Logger is a logger interface
+type Logger interface {
+ Debug(v ...interface{})
+ Debugf(format string, v ...interface{})
+ Error(v ...interface{})
+ Errorf(format string, v ...interface{})
+ Info(v ...interface{})
+ Infof(format string, v ...interface{})
+ Warn(v ...interface{})
+ Warnf(format string, v ...interface{})
+
+ Level() LogLevel
+ SetLevel(l LogLevel)
+
+ ShowSQL(show ...bool)
+ IsShowSQL() bool
+}
+
+var _ Logger = DiscardLogger{}

-// DiscardLogger don't log implementation for core.ILogger
+// DiscardLogger is an ILogger implementation that discards all logs
type DiscardLogger struct{}

// Debug empty implementation
@@ -49,12 +79,12 @@ func (DiscardLogger) Warn(v ...interface{}) {}
func (DiscardLogger) Warnf(format string, v ...interface{}) {}

// Level empty implementation
-func (DiscardLogger) Level() core.LogLevel {
- return core.LOG_UNKNOWN
+func (DiscardLogger) Level() LogLevel {
+ return LOG_UNKNOWN
}

// SetLevel empty implementation
-func (DiscardLogger) SetLevel(l core.LogLevel) {}
+func (DiscardLogger) SetLevel(l LogLevel) {}

// ShowSQL empty implementation
func (DiscardLogger) ShowSQL(show ...bool) {}
@@ -64,17 +94,17 @@ func (DiscardLogger) IsShowSQL() bool {
 return false
}

-// SimpleLogger is the default implment of core.ILogger
+// SimpleLogger is the default implementation of ILogger
type SimpleLogger struct {
 DEBUG *log.Logger
 ERR *log.Logger
 INFO *log.Logger
 WARN *log.Logger
- level core.LogLevel
+ level LogLevel
 showSQL bool
}

-var _ core.ILogger = &SimpleLogger{}
+var _ Logger = &SimpleLogger{}

// NewSimpleLogger use a special io.Writer as logger output
func NewSimpleLogger(out io.Writer) *SimpleLogger {
@@ -87,7 +117,7 @@ func NewSimpleLogger2(out io.Writer, prefix string, flag int) *SimpleLogger {
}

// NewSimpleLogger3 lets you customize your logger prefix and flag and logLevel
-func NewSimpleLogger3(out io.Writer, prefix string, flag int, l core.LogLevel) *SimpleLogger {
+func NewSimpleLogger3(out io.Writer, prefix string, flag int, l LogLevel) *SimpleLogger {
 return &SimpleLogger{
 DEBUG: log.New(out, fmt.Sprintf("%s [debug] ", prefix), flag),
 ERR: log.New(out, fmt.Sprintf("%s [error] ", prefix), flag),
@@ -97,82 +127,82 @@ func NewSimpleLogger3(out io.Writer, prefix string, flag int, l core.LogLevel) *
 }
}

-// Error implement core.ILogger
+// Error implement ILogger
func (s *SimpleLogger) Error(v ...interface{}) {
- if s.level <= core.LOG_ERR {
- s.ERR.Output(2, fmt.Sprint(v...))
+ if s.level <= LOG_ERR {
+ s.ERR.Output(2, fmt.Sprintln(v...))
 }
 return
}

-// Errorf implement core.ILogger
+// Errorf implement ILogger
func (s *SimpleLogger) Errorf(format string, v ...interface{}) {
- if s.level <= core.LOG_ERR {
+ if s.level <= LOG_ERR {
 s.ERR.Output(2, fmt.Sprintf(format, v...))
 }
 return
}

-// Debug implement core.ILogger
+// Debug implement ILogger
func (s *SimpleLogger) Debug(v ...interface{}) {
- if s.level <= core.LOG_DEBUG {
- s.DEBUG.Output(2, fmt.Sprint(v...))
+ if s.level <= LOG_DEBUG {
+ s.DEBUG.Output(2, fmt.Sprintln(v...))
 }
 return
}

-// Debugf implement core.ILogger
+// Debugf implement ILogger
func (s *SimpleLogger) Debugf(format string, v ...interface{}) {
- if s.level <= core.LOG_DEBUG {
+ if s.level <=
LOG_DEBUG { s.DEBUG.Output(2, fmt.Sprintf(format, v...)) } return } -// Info implement core.ILogger +// Info implement ILogger func (s *SimpleLogger) Info(v ...interface{}) { - if s.level <= core.LOG_INFO { - s.INFO.Output(2, fmt.Sprint(v...)) + if s.level <= LOG_INFO { + s.INFO.Output(2, fmt.Sprintln(v...)) } return } -// Infof implement core.ILogger +// Infof implement ILogger func (s *SimpleLogger) Infof(format string, v ...interface{}) { - if s.level <= core.LOG_INFO { + if s.level <= LOG_INFO { s.INFO.Output(2, fmt.Sprintf(format, v...)) } return } -// Warn implement core.ILogger +// Warn implement ILogger func (s *SimpleLogger) Warn(v ...interface{}) { - if s.level <= core.LOG_WARNING { - s.WARN.Output(2, fmt.Sprint(v...)) + if s.level <= LOG_WARNING { + s.WARN.Output(2, fmt.Sprintln(v...)) } return } -// Warnf implement core.ILogger +// Warnf implement ILogger func (s *SimpleLogger) Warnf(format string, v ...interface{}) { - if s.level <= core.LOG_WARNING { + if s.level <= LOG_WARNING { s.WARN.Output(2, fmt.Sprintf(format, v...)) } return } -// Level implement core.ILogger -func (s *SimpleLogger) Level() core.LogLevel { +// Level implement ILogger +func (s *SimpleLogger) Level() LogLevel { return s.level } -// SetLevel implement core.ILogger -func (s *SimpleLogger) SetLevel(l core.LogLevel) { +// SetLevel implement ILogger +func (s *SimpleLogger) SetLevel(l LogLevel) { s.level = l return } -// ShowSQL implement core.ILogger +// ShowSQL implement ILogger func (s *SimpleLogger) ShowSQL(show ...bool) { if len(show) == 0 { s.showSQL = true @@ -181,7 +211,7 @@ func (s *SimpleLogger) ShowSQL(show ...bool) { s.showSQL = show[0] } -// IsShowSQL implement core.ILogger +// IsShowSQL implement ILogger func (s *SimpleLogger) IsShowSQL() bool { return s.showSQL } diff --git a/log/logger_context.go b/log/logger_context.go new file mode 100644 index 0000000..c2d94dc --- /dev/null +++ b/log/logger_context.go @@ -0,0 +1,121 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+
+package log
+
+import (
+ "context"
+ "fmt"
+ "time"
+)
+
+// LogContext represents a log context
+type LogContext struct {
+ Ctx context.Context
+ SQL string // log content or SQL
+ Args []interface{} // if it's SQL, these are its arguments
+ ExecuteTime time.Duration
+ Err error // SQL executed error
+}
+
+// SQLLogger represents an interface to log SQL
+type SQLLogger interface {
+ BeforeSQL(context LogContext) // only invoked when IsShowSQL is true
+ AfterSQL(context LogContext) // only invoked when IsShowSQL is true
+}
+
+// ContextLogger represents a logger interface with context
+type ContextLogger interface {
+ SQLLogger
+
+ Debugf(format string, v ...interface{})
+ Errorf(format string, v ...interface{})
+ Infof(format string, v ...interface{})
+ Warnf(format string, v ...interface{})
+
+ Level() LogLevel
+ SetLevel(l LogLevel)
+
+ ShowSQL(show ...bool)
+ IsShowSQL() bool
+}
+
+var (
+ _ ContextLogger = &LoggerAdapter{}
+)
+
+// enumerate all the context keys
+var (
+ SessionIDKey = "__xorm_session_id"
+ SessionShowSQLKey = "__xorm_show_sql"
+)
+
+// LoggerAdapter wraps a Logger interface as a ContextLogger interface
+type LoggerAdapter struct {
+ logger Logger
+}
+
+// NewLoggerAdapter creates an adapter for the old xorm logger interface
+func NewLoggerAdapter(logger Logger) ContextLogger {
+ return &LoggerAdapter{
+ logger: logger,
+ }
+}
+
+// BeforeSQL implements ContextLogger
+func (l *LoggerAdapter) BeforeSQL(ctx LogContext) {}
+
+// AfterSQL implements ContextLogger
+func (l *LoggerAdapter) AfterSQL(ctx LogContext) {
+ var sessionPart string
+ v := ctx.Ctx.Value(SessionIDKey)
+ if key, ok := v.(string); ok {
+ sessionPart = fmt.Sprintf(" [%s]", key)
+ }
+ if ctx.ExecuteTime > 0 {
+ l.logger.Infof("[SQL]%s %s %v - %v", sessionPart, ctx.SQL, ctx.Args, ctx.ExecuteTime)
+ } else {
+ l.logger.Infof("[SQL]%s %s %v", sessionPart, ctx.SQL, ctx.Args)
+ }
+}
+
+// Debugf implements ContextLogger
+func (l *LoggerAdapter) Debugf(format string, v ...interface{}) {
+ l.logger.Debugf(format, v...)
+}
+
+// Errorf implements ContextLogger
+func (l *LoggerAdapter) Errorf(format string, v ...interface{}) {
+ l.logger.Errorf(format, v...)
+}
+
+// Infof implements ContextLogger
+func (l *LoggerAdapter) Infof(format string, v ...interface{}) {
+ l.logger.Infof(format, v...)
+}
+
+// Warnf implements ContextLogger
+func (l *LoggerAdapter) Warnf(format string, v ...interface{}) {
+ l.logger.Warnf(format, v...)
+}
+
+// Level implements ContextLogger
+func (l *LoggerAdapter) Level() LogLevel {
+ return l.logger.Level()
+}
+
+// SetLevel implements ContextLogger
+func (l *LoggerAdapter) SetLevel(lv LogLevel) {
+ l.logger.SetLevel(lv)
+}
+
+// ShowSQL implements ContextLogger
+func (l *LoggerAdapter) ShowSQL(show ...bool) {
+ l.logger.ShowSQL(show...)
+} + +// IsShowSQL implements ContextLogger +func (l *LoggerAdapter) IsShowSQL() bool { + return l.logger.IsShowSQL() +} diff --git a/syslogger.go b/log/syslogger.go similarity index 88% rename from syslogger.go rename to log/syslogger.go index 320d814..0b3e381 100644 --- a/syslogger.go +++ b/log/syslogger.go @@ -4,16 +4,14 @@ // +build !windows,!nacl,!plan9 -package xorm +package log import ( "fmt" "log/syslog" - - "github.com/xormplus/core" ) -var _ core.ILogger = &SyslogLogger{} +var _ Logger = &SyslogLogger{} // SyslogLogger will be depricated type SyslogLogger struct { @@ -21,7 +19,7 @@ type SyslogLogger struct { showSQL bool } -// NewSyslogLogger implements core.ILogger +// NewSyslogLogger implements Logger func NewSyslogLogger(w *syslog.Writer) *SyslogLogger { return &SyslogLogger{w: w} } @@ -67,12 +65,12 @@ func (s *SyslogLogger) Warnf(format string, v ...interface{}) { } // Level shows log level -func (s *SyslogLogger) Level() core.LogLevel { - return core.LOG_UNKNOWN +func (s *SyslogLogger) Level() LogLevel { + return LOG_UNKNOWN } // SetLevel always return error, as current log/syslog package doesn't allow to set priority level after syslog.Writer created -func (s *SyslogLogger) SetLevel(l core.LogLevel) {} +func (s *SyslogLogger) SetLevel(l LogLevel) {} // ShowSQL set if logging SQL func (s *SyslogLogger) ShowSQL(show ...bool) { diff --git a/migrate/migrate.go b/migrate/migrate.go index b56d478..abe16eb 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -13,7 +13,7 @@ type MigrateFunc func(*xorm.Engine) error // RollbackFunc is the func signature for rollbacking. type RollbackFunc func(*xorm.Engine) error -// InitSchemaFunc is the func signature for initializing the schema. +// InitSchemaFunc is the func signature for initializing the schemas. type InitSchemaFunc func(*xorm.Engine) error // Options define options for all migrations. @@ -34,7 +34,7 @@ type Migration struct { Rollback RollbackFunc } -// Migrate represents a collection of all migrations of a database schema. +// Migrate represents a collection of all migrations of a database schemas. type Migrate struct { db *xorm.Engine options *Options diff --git a/migrate/migrate_test.go b/migrate/migrate_test.go index a632cbf..1e4622c 100644 --- a/migrate/migrate_test.go +++ b/migrate/migrate_test.go @@ -7,8 +7,8 @@ import ( "testing" _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" "github.com/xormplus/xorm" - "gopkg.in/stretchr/testify.v1/assert" ) type Person struct { diff --git a/names/mapper.go b/names/mapper.go new file mode 100644 index 0000000..4aaf084 --- /dev/null +++ b/names/mapper.go @@ -0,0 +1,258 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+
+package names
+
+import (
+ "strings"
+ "sync"
+)
+
+// Mapper represents a name conversion between a struct's field names and a table's column names
+type Mapper interface {
+ Obj2Table(string) string
+ Table2Obj(string) string
+}
+
+type CacheMapper struct {
+ oriMapper Mapper
+ obj2tableCache map[string]string
+ obj2tableMutex sync.RWMutex
+ table2objCache map[string]string
+ table2objMutex sync.RWMutex
+}
+
+func NewCacheMapper(mapper Mapper) *CacheMapper {
+ return &CacheMapper{oriMapper: mapper, obj2tableCache: make(map[string]string),
+ table2objCache: make(map[string]string),
+ }
+}
+
+func (m *CacheMapper) Obj2Table(o string) string {
+ m.obj2tableMutex.RLock()
+ t, ok := m.obj2tableCache[o]
+ m.obj2tableMutex.RUnlock()
+ if ok {
+ return t
+ }
+
+ t = m.oriMapper.Obj2Table(o)
+ m.obj2tableMutex.Lock()
+ m.obj2tableCache[o] = t
+ m.obj2tableMutex.Unlock()
+ return t
+}
+
+func (m *CacheMapper) Table2Obj(t string) string {
+ m.table2objMutex.RLock()
+ o, ok := m.table2objCache[t]
+ m.table2objMutex.RUnlock()
+ if ok {
+ return o
+ }
+
+ o = m.oriMapper.Table2Obj(t)
+ m.table2objMutex.Lock()
+ m.table2objCache[t] = o
+ m.table2objMutex.Unlock()
+ return o
+}
+
+// SameMapper implements IMapper and provides the same name between struct and
+// database table
+type SameMapper struct {
+}
+
+func (m SameMapper) Obj2Table(o string) string {
+ return o
+}
+
+func (m SameMapper) Table2Obj(t string) string {
+ return t
+}
+
+// SnakeMapper implements IMapper and provides name translation between
+// struct and database table
+type SnakeMapper struct {
+}
+
+func snakeCasedName(name string) string {
+ newstr := make([]rune, 0)
+ for idx, chr := range name {
+ if isUpper := 'A' <= chr && chr <= 'Z'; isUpper {
+ if idx > 0 {
+ newstr = append(newstr, '_')
+ }
+ chr -= ('A' - 'a')
+ }
+ newstr = append(newstr, chr)
+ }
+
+ return string(newstr)
+}
+
+func (mapper SnakeMapper) Obj2Table(name string) string {
+ return snakeCasedName(name)
+}
+
+func titleCasedName(name string) string {
+ newstr := make([]rune, 0)
+ upNextChar := true
+
+ name = strings.ToLower(name)
+
+ for _, chr := range name {
+ switch {
+ case upNextChar:
+ upNextChar = false
+ if 'a' <= chr && chr <= 'z' {
+ chr -= ('a' - 'A')
+ }
+ case chr == '_':
+ upNextChar = true
+ continue
+ }
+
+ newstr = append(newstr, chr)
+ }
+
+ return string(newstr)
+}
+
+func (mapper SnakeMapper) Table2Obj(name string) string {
+ return titleCasedName(name)
+}
+
+// GonicMapper implements IMapper. It will consider initialisms when mapping names.
+// E.g.
+type GonicMapper map[string]bool
+
+func isASCIIUpper(r rune) bool {
+	return 'A' <= r && r <= 'Z'
+}
+
+func toASCIIUpper(r rune) rune {
+	if 'a' <= r && r <= 'z' {
+		r -= ('a' - 'A')
+	}
+	return r
+}
+
+func gonicCasedName(name string) string {
+	newstr := make([]rune, 0, len(name)+3)
+	for idx, chr := range name {
+		if isASCIIUpper(chr) && idx > 0 {
+			if !isASCIIUpper(newstr[len(newstr)-1]) {
+				newstr = append(newstr, '_')
+			}
+		}
+
+		if !isASCIIUpper(chr) && idx > 1 {
+			l := len(newstr)
+			if isASCIIUpper(newstr[l-1]) && isASCIIUpper(newstr[l-2]) {
+				newstr = append(newstr, newstr[l-1])
+				newstr[l-1] = '_'
+			}
+		}
+
+		newstr = append(newstr, chr)
+	}
+	return strings.ToLower(string(newstr))
+}
+
+func (mapper GonicMapper) Obj2Table(name string) string {
+	return gonicCasedName(name)
+}
+
+func (mapper GonicMapper) Table2Obj(name string) string {
+	newstr := make([]rune, 0)
+
+	name = strings.ToLower(name)
+	parts := strings.Split(name, "_")
+
+	for _, p := range parts {
+		_, isInitialism := mapper[strings.ToUpper(p)]
+		for i, r := range p {
+			if i == 0 || isInitialism {
+				r = toASCIIUpper(r)
+			}
+			newstr = append(newstr, r)
+		}
+	}
+
+	return string(newstr)
+}
+
+// LintGonicMapper is a GonicMapper that contains a list of common initialisms taken from golang/lint
+var LintGonicMapper = GonicMapper{
+	"API":   true,
+	"ASCII": true,
+	"CPU":   true,
+	"CSS":   true,
+	"DNS":   true,
+	"EOF":   true,
+	"GUID":  true,
+	"HTML":  true,
+	"HTTP":  true,
+	"HTTPS": true,
+	"ID":    true,
+	"IP":    true,
+	"JSON":  true,
+	"LHS":   true,
+	"QPS":   true,
+	"RAM":   true,
+	"RHS":   true,
+	"RPC":   true,
+	"SLA":   true,
+	"SMTP":  true,
+	"SSH":   true,
+	"TLS":   true,
+	"TTL":   true,
+	"UI":    true,
+	"UID":   true,
+	"UUID":  true,
+	"URI":   true,
+	"URL":   true,
+	"UTF8":  true,
+	"VM":    true,
+	"XML":   true,
+	"XSRF":  true,
+	"XSS":   true,
+}
+
+// PrefixMapper provides prefix table name support
+type PrefixMapper struct {
+	Mapper Mapper
+	Prefix string
+}
+
+func (mapper PrefixMapper) Obj2Table(name string) string {
+	return mapper.Prefix + mapper.Mapper.Obj2Table(name)
+}
+
+func (mapper PrefixMapper) Table2Obj(name string) string {
+	return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):])
+}
+
+// NewPrefixMapper creates a PrefixMapper with the given base mapper and prefix
+func NewPrefixMapper(mapper Mapper, prefix string) PrefixMapper {
+	return PrefixMapper{mapper, prefix}
+}
+
+// SuffixMapper provides suffix table name support
+type SuffixMapper struct {
+	Mapper Mapper
+	Suffix string
+}
+
+func (mapper SuffixMapper) Obj2Table(name string) string {
+	return mapper.Mapper.Obj2Table(name) + mapper.Suffix
+}
+
+func (mapper SuffixMapper) Table2Obj(name string) string {
+	return mapper.Mapper.Table2Obj(name[:len(name)-len(mapper.Suffix)])
+}
+
+// NewSuffixMapper creates a SuffixMapper with the given base mapper and suffix
+func NewSuffixMapper(mapper Mapper, suffix string) SuffixMapper {
+	return SuffixMapper{mapper, suffix}
+}
diff --git a/names/mapper_test.go b/names/mapper_test.go
new file mode 100644
index 0000000..0edfd2a
--- /dev/null
+++ b/names/mapper_test.go
@@ -0,0 +1,49 @@
+// Copyright 2019 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
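+
+// A short sketch of composing the prefix/suffix wrappers from mapper.go with
+// a base mapper (the "app_" prefix here is just an illustrative value):
+//
+//	m := names.NewPrefixMapper(names.SnakeMapper{}, "app_")
+//	m.Obj2Table("UserInfo")      // "app_user_info"
+//	m.Table2Obj("app_user_info") // "UserInfo"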
+
+package names
+
+import (
+	"testing"
+)
+
+func TestGonicMapperFromObj(t *testing.T) {
+	testCases := map[string]string{
+		"HTTPLib":             "http_lib",
+		"id":                  "id",
+		"ID":                  "id",
+		"IDa":                 "i_da",
+		"iDa":                 "i_da",
+		"IDAa":                "id_aa",
+		"aID":                 "a_id",
+		"aaID":                "aa_id",
+		"aaaID":               "aaa_id",
+		"MyREalFunkYLONgNAME": "my_r_eal_funk_ylo_ng_name",
+	}
+
+	for in, expected := range testCases {
+		out := gonicCasedName(in)
+		if out != expected {
+			t.Errorf("Given %s, expected %s but got %s", in, expected, out)
+		}
+	}
+}
+
+func TestGonicMapperToObj(t *testing.T) {
+	testCases := map[string]string{
+		"http_lib":                  "HTTPLib",
+		"id":                        "ID",
+		"ida":                       "Ida",
+		"id_aa":                     "IDAa",
+		"aa_id":                     "AaID",
+		"my_r_eal_funk_ylo_ng_name": "MyREalFunkYloNgName",
+	}
+
+	for in, expected := range testCases {
+		out := LintGonicMapper.Table2Obj(in)
+		if out != expected {
+			t.Errorf("Given %s, expected %s but got %s", in, expected, out)
+		}
+	}
+}
diff --git a/names/table_name.go b/names/table_name.go
new file mode 100644
index 0000000..0afb1ae
--- /dev/null
+++ b/names/table_name.go
@@ -0,0 +1,56 @@
+// Copyright 2020 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package names
+
+import (
+	"reflect"
+	"sync"
+)
+
+// TableName is an interface to define a customized table name
+type TableName interface {
+	TableName() string
+}
+
+var (
+	tpTableName = reflect.TypeOf((*TableName)(nil)).Elem()
+	tvCache     sync.Map
+)
+
+// GetTableName returns the table name for v: the result of the TableName
+// method if the value's type implements it, otherwise the mapper conversion
+// of the type name.
+func GetTableName(mapper Mapper, v reflect.Value) string {
+	if v.Type().Implements(tpTableName) {
+		return v.Interface().(TableName).TableName()
+	}
+
+	if v.Kind() == reflect.Ptr {
+		v = v.Elem()
+		if v.Type().Implements(tpTableName) {
+			return v.Interface().(TableName).TableName()
+		}
+	} else if v.CanAddr() {
+		v1 := v.Addr()
+		if v1.Type().Implements(tpTableName) {
+			return v1.Interface().(TableName).TableName()
+		}
+	} else {
+		name, ok := tvCache.Load(v.Type())
+		if ok {
+			if name.(string) != "" {
+				return name.(string)
+			}
+		} else {
+			v2 := reflect.New(v.Type())
+			if v2.Type().Implements(tpTableName) {
+				tableName := v2.Interface().(TableName).TableName()
+				tvCache.Store(v.Type(), tableName)
+				return tableName
+			}
+
+			tvCache.Store(v.Type(), "")
+		}
+	}
+
+	return mapper.Obj2Table(v.Type().Name())
+}
diff --git a/names/table_name_test.go b/names/table_name_test.go
new file mode 100644
index 0000000..76da413
--- /dev/null
+++ b/names/table_name_test.go
@@ -0,0 +1,140 @@
+// Copyright 2020 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
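+
+// A minimal sketch of overriding the mapper-derived name through the
+// TableName interface from table_name.go (the Account type is hypothetical):
+//
+//	type Account struct{}
+//
+//	func (Account) TableName() string { return "my_accounts" }
+//
+//	// GetTableName(SnakeMapper{}, reflect.ValueOf(Account{})) // "my_accounts"
+//	// without the method, GetTableName falls back to the mapper: "account"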
+ +package names + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type Userinfo struct { + Uid int64 `xorm:"id pk not null autoincr"` + Username string `xorm:"unique"` + Departname string + Alias string `xorm:"-"` + Created time.Time + Detail Userdetail `xorm:"detail_id int(11)"` + Height float64 + Avatar []byte + IsMan bool +} + +type Userdetail struct { + Id int64 + Intro string `xorm:"text"` + Profile string `xorm:"varchar(2000)"` +} + +type MyGetCustomTableImpletation struct { + Id int64 `json:"id"` + Name string `json:"name"` +} + +const getCustomTableName = "GetCustomTableInterface" + +func (MyGetCustomTableImpletation) TableName() string { + return getCustomTableName +} + +type TestTableNameStruct struct{} + +const getTestTableName = "my_test_table_name_struct" + +func (t *TestTableNameStruct) TableName() string { + return getTestTableName +} + +func TestGetTableName(t *testing.T) { + var kases = []struct { + mapper Mapper + v reflect.Value + expectedTableName string + }{ + { + SnakeMapper{}, + reflect.ValueOf(new(Userinfo)), + "userinfo", + }, + { + SnakeMapper{}, + reflect.ValueOf(Userinfo{}), + "userinfo", + }, + { + SameMapper{}, + reflect.ValueOf(new(Userinfo)), + "Userinfo", + }, + { + SameMapper{}, + reflect.ValueOf(Userinfo{}), + "Userinfo", + }, + { + SnakeMapper{}, + reflect.ValueOf(new(MyGetCustomTableImpletation)), + getCustomTableName, + }, + { + SnakeMapper{}, + reflect.ValueOf(MyGetCustomTableImpletation{}), + getCustomTableName, + }, + { + SnakeMapper{}, + reflect.ValueOf(new(TestTableNameStruct)), + new(TestTableNameStruct).TableName(), + }, + { + SnakeMapper{}, + reflect.ValueOf(new(TestTableNameStruct)), + getTestTableName, + }, + { + SnakeMapper{}, + reflect.ValueOf(TestTableNameStruct{}), + getTestTableName, + }, + } + + for _, kase := range kases { + assert.EqualValues(t, kase.expectedTableName, GetTableName(kase.mapper, kase.v)) + } +} + +type OAuth2Application struct { +} + +// TableName sets the table name to `oauth2_application` +func (app *OAuth2Application) TableName() string { + return "oauth2_application" +} + +func TestGonicMapperCustomTable(t *testing.T) { + assert.EqualValues(t, "oauth2_application", + GetTableName(LintGonicMapper, reflect.ValueOf(new(OAuth2Application)))) + assert.EqualValues(t, "oauth2_application", + GetTableName(LintGonicMapper, reflect.ValueOf(OAuth2Application{}))) +} + +type MyTable struct { + Idx int +} + +func (t *MyTable) TableName() string { + return fmt.Sprintf("mytable_%d", t.Idx) +} + +func TestMyTable(t *testing.T) { + var table MyTable + for i := 0; i < 10; i++ { + table.Idx = i + assert.EqualValues(t, fmt.Sprintf("mytable_%d", i), GetTableName(SameMapper{}, reflect.ValueOf(&table))) + } +} diff --git a/processors.go b/processors.go index dcd9c6a..8697e30 100644 --- a/processors.go +++ b/processors.go @@ -76,3 +76,69 @@ func (session *Session) executeProcessors() error { } return nil } + +func cleanupProcessorsClosures(slices *[]func(interface{})) { + if len(*slices) > 0 { + *slices = make([]func(interface{}), 0) + } +} + +func executeBeforeClosures(session *Session, bean interface{}) { + // handle before delete processors + for _, closure := range session.beforeClosures { + closure(bean) + } + cleanupProcessorsClosures(&session.beforeClosures) +} + +func executeBeforeSet(bean interface{}, fields []string, scanResults []interface{}) { + if b, hasBeforeSet := bean.(BeforeSetProcessor); hasBeforeSet { + for ii, key := range fields { + b.BeforeSet(key, 
Cell(scanResults[ii].(*interface{}))) + } + } +} + +func executeAfterSet(bean interface{}, fields []string, scanResults []interface{}) { + if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet { + for ii, key := range fields { + b.AfterSet(key, Cell(scanResults[ii].(*interface{}))) + } + } +} + +func buildAfterProcessors(session *Session, bean interface{}) { + // handle afterClosures + for _, closure := range session.afterClosures { + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(sess *Session, bean interface{}) error { + closure(bean) + return nil + }, + session: session, + bean: bean, + }) + } + + if a, has := bean.(AfterLoadProcessor); has { + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(sess *Session, bean interface{}) error { + a.AfterLoad() + return nil + }, + session: session, + bean: bean, + }) + } + + if a, has := bean.(AfterLoadSessionProcessor); has { + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(sess *Session, bean interface{}) error { + a.AfterLoad(sess) + return nil + }, + session: session, + bean: bean, + }) + } +} diff --git a/rows.go b/rows.go index a61a0ac..36ed5e6 100644 --- a/rows.go +++ b/rows.go @@ -6,10 +6,13 @@ package xorm import ( "database/sql" + "errors" "fmt" "reflect" - "github.com/xormplus/core" + "github.com/xormplus/builder" + "github.com/xormplus/xorm/core" + "github.com/xormplus/xorm/internal/utils" ) // Rows rows wrapper a rows to @@ -29,7 +32,14 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { var args []interface{} var err error - if err = rows.session.statement.setRefBean(bean); err != nil { + beanValue := reflect.ValueOf(bean) + if beanValue.Kind() != reflect.Ptr { + return nil, errors.New("needs a pointer to a value") + } else if beanValue.Elem().Kind() == reflect.Ptr { + return nil, errors.New("a pointer to a pointer is not allowed") + } + + if err = rows.session.statement.SetRefBean(bean); err != nil { return nil, err } @@ -38,12 +48,37 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { } if rows.session.statement.RawSQL == "" { - sqlStr, args, err = rows.session.statement.genGetSQL(bean) + var autoCond builder.Cond + var addedTableName = (len(session.statement.JoinStr) > 0) + var table = rows.session.statement.RefTable + + if !session.statement.NoAutoCondition { + var err error + autoCond, err = session.statement.BuildConds(table, bean, true, true, false, true, addedTableName) + if err != nil { + return nil, err + } + } else { + if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled + var colName = session.engine.Quote(col.Name) + if addedTableName { + var nm = session.statement.TableName() + if len(session.statement.TableAlias) > 0 { + nm = session.statement.TableAlias + } + colName = session.engine.Quote(nm) + "." 
+ colName
+				}
+
+				autoCond = session.statement.CondDeleted(col)
+			}
+		}
+
+		sqlStr, args, err = rows.session.statement.GenFindSQL(autoCond)
 		if err != nil {
 			return nil, err
 		}
 	} else {
-		sqlStr = rows.session.statement.RawSQL
+		sqlStr = rows.session.statement.GenRawSQL()
 		args = rows.session.statement.RawParams
 	}
@@ -84,7 +119,7 @@ func (rows *Rows) Scan(bean interface{}) error {
 		return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType)
 	}

-	if err := rows.session.statement.setRefBean(bean); err != nil {
+	if err := rows.session.statement.SetRefBean(bean); err != nil {
 		return err
 	}
@@ -98,7 +133,7 @@ func (rows *Rows) Scan(bean interface{}) error {
 		return err
 	}

-	dataStruct := rValue(bean)
+	dataStruct := utils.ReflectValue(bean)
 	_, err = rows.session.slice2Bean(scanResults, fields, bean, &dataStruct, rows.session.statement.RefTable)
 	if err != nil {
 		return err
diff --git a/schemas/column.go b/schemas/column.go
new file mode 100644
index 0000000..418629a
--- /dev/null
+++ b/schemas/column.go
@@ -0,0 +1,117 @@
+// Copyright 2019 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package schemas
+
+import (
+	"fmt"
+	"reflect"
+	"strings"
+	"time"
+)
+
+const (
+	TWOSIDES = iota + 1
+	ONLYTODB
+	ONLYFROMDB
+)
+
+// Column defines database column
+type Column struct {
+	Name            string
+	TableName       string
+	FieldName       string // Available only when parsed from a struct
+	SQLType         SQLType
+	IsJSON          bool
+	Length          int
+	Length2         int
+	Nullable        bool
+	Default         string
+	Indexes         map[string]int
+	IsPrimaryKey    bool
+	IsAutoIncrement bool
+	MapType         int
+	IsCreated       bool
+	IsUpdated       bool
+	IsDeleted       bool
+	IsCascade       bool
+	IsVersion       bool
+	DefaultIsEmpty  bool // true means the column has no default value set; it does not mean the default value is empty
+	EnumOptions     map[string]int
+	SetOptions      map[string]int
+	DisableTimeZone bool
+	TimeZone        *time.Location // column specified time zone
+	Comment         string
+}
+
+// NewColumn creates a new column
+func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable bool) *Column {
+	return &Column{
+		Name:            name,
+		TableName:       "",
+		FieldName:       fieldName,
+		SQLType:         sqlType,
+		Length:          len1,
+		Length2:         len2,
+		Nullable:        nullable,
+		Default:         "",
+		Indexes:         make(map[string]int),
+		IsPrimaryKey:    false,
+		IsAutoIncrement: false,
+		MapType:         TWOSIDES,
+		IsCreated:       false,
+		IsUpdated:       false,
+		IsDeleted:       false,
+		IsCascade:       false,
+		IsVersion:       false,
+		DefaultIsEmpty:  true, // default should be no default
+		EnumOptions:     make(map[string]int),
+		Comment:         "",
+	}
+}
+
+// ValueOf returns the column's corresponding field value of the struct bean
+func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) {
+	dataStruct := reflect.Indirect(reflect.ValueOf(bean))
+	return col.ValueOfV(&dataStruct)
+}
+
+// ValueOfV returns the column's corresponding field value from the given reflect value
+func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
+	var fieldValue reflect.Value
+	fieldPath := strings.Split(col.FieldName, ".")
+
+	if dataStruct.Type().Kind() == reflect.Map {
+		keyValue := reflect.ValueOf(fieldPath[len(fieldPath)-1])
+		fieldValue = dataStruct.MapIndex(keyValue)
+		return &fieldValue, nil
+	} else if dataStruct.Type().Kind() == reflect.Interface {
+		structValue := reflect.ValueOf(dataStruct.Interface())
+		dataStruct = &structValue
+	}
+
+	level := len(fieldPath)
+	fieldValue = dataStruct.FieldByName(fieldPath[0])
+	for i := 0; i < level-1; i++ {
+		if !fieldValue.IsValid() {
+			break
+		
}
+		if fieldValue.Kind() == reflect.Struct {
+			fieldValue = fieldValue.FieldByName(fieldPath[i+1])
+		} else if fieldValue.Kind() == reflect.Ptr {
+			if fieldValue.IsNil() {
+				fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
+			}
+			fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1])
+		} else {
+			return nil, fmt.Errorf("field %v is not valid", col.FieldName)
+		}
+	}
+
+	if !fieldValue.IsValid() {
+		return nil, fmt.Errorf("field %v is not valid", col.FieldName)
+	}
+
+	return &fieldValue, nil
+}
diff --git a/schemas/index.go b/schemas/index.go
new file mode 100644
index 0000000..9541250
--- /dev/null
+++ b/schemas/index.go
@@ -0,0 +1,72 @@
+// Copyright 2019 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package schemas
+
+import (
+	"fmt"
+	"strings"
+)
+
+// enumerate all index types
+const (
+	IndexType = iota + 1
+	UniqueType
+)
+
+// Index represents a database index
+type Index struct {
+	IsRegular bool
+	Name      string
+	Type      int
+	Cols      []string
+}
+
+// NewIndex creates a new index object
+func NewIndex(name string, indexType int) *Index {
+	return &Index{true, name, indexType, make([]string, 0)}
+}
+
+// XName returns the index name prefixed with UQE_/IDX_ and the table name,
+// unless the name already carries such a prefix
+func (index *Index) XName(tableName string) string {
+	if !strings.HasPrefix(index.Name, "UQE_") &&
+		!strings.HasPrefix(index.Name, "IDX_") {
+		tableParts := strings.Split(strings.Replace(tableName, `"`, "", -1), ".")
+		tableName = tableParts[len(tableParts)-1]
+		if index.Type == UniqueType {
+			return fmt.Sprintf("UQE_%v_%v", tableName, index.Name)
+		}
+		return fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
+	}
+	return index.Name
+}
+
+// AddColumn add columns which will be composite index
+func (index *Index) AddColumn(cols ...string) {
+	index.Cols = append(index.Cols, cols...)
+}
+
+// Equal returns true if the two indexes have the same type and the same
+// column set, regardless of column order
+func (index *Index) Equal(dst *Index) bool {
+	if index.Type != dst.Type {
+		return false
+	}
+	if len(index.Cols) != len(dst.Cols) {
+		return false
+	}
+
+	for i := 0; i < len(index.Cols); i++ {
+		var found bool
+		for j := 0; j < len(dst.Cols); j++ {
+			if index.Cols[i] == dst.Cols[j] {
+				found = true
+				break
+			}
+		}
+		if !found {
+			return false
+		}
+	}
+	return true
+}
diff --git a/schemas/pk.go b/schemas/pk.go
new file mode 100644
index 0000000..aff9606
--- /dev/null
+++ b/schemas/pk.go
@@ -0,0 +1,41 @@
+// Copyright 2019 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package schemas
+
+import (
+	"bytes"
+	"encoding/gob"
+
+	"github.com/xormplus/xorm/internal/utils"
+)
+
+// PK represents a composite primary key
+type PK []interface{}
+
+// NewPK creates a PK from the given key values
+func NewPK(pks ...interface{}) *PK {
+	p := PK(pks)
+	return &p
+}
+
+// IsZero returns true if any component of the key is a zero value
+func (p *PK) IsZero() bool {
+	for _, k := range *p {
+		if utils.IsZero(k) {
+			return true
+		}
+	}
+	return false
+}
+
+// ToString encodes the PK into a string via gob
+func (p *PK) ToString() (string, error) {
+	buf := new(bytes.Buffer)
+	enc := gob.NewEncoder(buf)
+	err := enc.Encode(*p)
+	return buf.String(), err
+}
+
+// FromString decodes a gob-encoded string back into the PK
+func (p *PK) FromString(content string) error {
+	dec := gob.NewDecoder(bytes.NewBufferString(content))
+	err := dec.Decode(p)
+	return err
+}
diff --git a/schemas/pk_test.go b/schemas/pk_test.go
new file mode 100644
index 0000000..a88b70d
--- /dev/null
+++ b/schemas/pk_test.go
@@ -0,0 +1,36 @@
+// Copyright 2019 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
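+
+// A minimal sketch of the gob round-trip implemented in pk.go; the test
+// below exercises the same path:
+//
+//	p := schemas.NewPK(int64(1), "name")
+//	s, _ := p.ToString() // gob-encoded string form of the composite key
+//	var q schemas.PK
+//	_ = (&q).FromString(s) // q now holds the same values as *p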
+
+package schemas
+
+import (
+	"reflect"
+	"testing"
+)
+
+func TestPK(t *testing.T) {
+	p := NewPK(1, 3, "string")
+	str, err := p.ToString()
+	if err != nil {
+		t.Error(err)
+	}
+	t.Log(str)
+
+	s := &PK{}
+	err = s.FromString(str)
+	if err != nil {
+		t.Error(err)
+	}
+	t.Log(s)
+
+	if len(*p) != len(*s) {
+		t.Fatal("p", *p, "should be equal", *s)
+	}
+
+	for i, ori := range *p {
+		if ori != (*s)[i] {
+			t.Fatal("ori", ori, reflect.ValueOf(ori), "should be equal", (*s)[i], reflect.ValueOf((*s)[i]))
+		}
+	}
+}
diff --git a/schemas/quote.go b/schemas/quote.go
new file mode 100644
index 0000000..c44abe2
--- /dev/null
+++ b/schemas/quote.go
@@ -0,0 +1,240 @@
+// Copyright 2020 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package schemas
+
+import (
+	"strings"
+)
+
+// Quoter represents a quoter to the SQL table name and column name
+type Quoter struct {
+	Prefix     byte
+	Suffix     byte
+	IsReserved func(string) bool
+}
+
+var (
+	// AlwaysNoReserve always treats the word as not reserved
+	AlwaysNoReserve = func(string) bool { return false }
+
+	// AlwaysReserve always treats the word as reserved
+	AlwaysReserve = func(string) bool { return true }
+
+	// CommanQuoteMark represents the common quote mark
+	CommanQuoteMark byte = '`'
+
+	// CommonQuoter represents a common quoter
+	CommonQuoter = Quoter{CommanQuoteMark, CommanQuoteMark, AlwaysReserve}
+)
+
+// IsEmpty returns true if no quote marks are set
+func (q Quoter) IsEmpty() bool {
+	return q.Prefix == 0 && q.Suffix == 0
+}
+
+// Quote quotes a table or column name
+func (q Quoter) Quote(s string) string {
+	var buf strings.Builder
+	q.QuoteTo(&buf, s)
+	return buf.String()
+}
+
+// Trim removes quotes from s
+func (q Quoter) Trim(s string) string {
+	if len(s) < 2 {
+		return s
+	}
+
+	var buf strings.Builder
+	for i := 0; i < len(s); i++ {
+		switch {
+		case i == 0 && s[i] == q.Prefix:
+		case i == len(s)-1 && s[i] == q.Suffix:
+		case s[i] == q.Suffix && s[i+1] == '.':
+		case s[i] == q.Prefix && s[i-1] == '.':
+		default:
+			buf.WriteByte(s[i])
+		}
+	}
+	return buf.String()
+}
+
+// Join is the quoting-aware equivalent of strings.Join
+func (q Quoter) Join(a []string, sep string) string {
+	var b strings.Builder
+	q.JoinWrite(&b, a, sep)
+	return b.String()
+}
+
+// JoinWrite joins the quoted elements of a with sep and writes them into b
+func (q Quoter) JoinWrite(b *strings.Builder, a []string, sep string) error {
+	if len(a) == 0 {
+		return nil
+	}
+
+	n := len(sep) * (len(a) - 1)
+	for i := 0; i < len(a); i++ {
+		n += len(a[i])
+	}
+
+	b.Grow(n)
+	for i, s := range a {
+		if i > 0 {
+			if _, err := b.WriteString(sep); err != nil {
+				return err
+			}
+		}
+		if s != "*" {
+			q.QuoteTo(b, strings.TrimSpace(s))
+		}
+	}
+	return nil
+}
+
+func findWord(v string, start int) int {
+	for j := start; j < len(v); j++ {
+		switch v[j] {
+		case '.', ' ':
+			return j
+		}
+	}
+	return len(v)
+}
+
+func findStart(value string, start int) int {
+	if value[start] == '.'
 {
+		return start + 1
+	}
+	if value[start] != ' ' {
+		return start
+	}
+
+	var k = -1
+	for j := start; j < len(value); j++ {
+		if value[j] != ' ' {
+			k = j
+			break
+		}
+	}
+	if k == -1 {
+		return len(value)
+	}
+
+	// guard k+1 so a trailing single character cannot index past the end
+	if k+1 < len(value) && (value[k] == 'A' || value[k] == 'a') && (value[k+1] == 'S' || value[k+1] == 's') {
+		k = k + 2
+	}
+
+	for j := k; j < len(value); j++ {
+		if value[j] != ' ' {
+			return j
+		}
+	}
+	return len(value)
+}
+
+func (q Quoter) quoteWordTo(buf *strings.Builder, word string) error {
+	var realWord = word
+	if (word[0] == CommanQuoteMark && word[len(word)-1] == CommanQuoteMark) ||
+		(word[0] == q.Prefix && word[len(word)-1] == q.Suffix) {
+		realWord = word[1 : len(word)-1]
+	}
+
+	if q.IsEmpty() {
+		_, err := buf.WriteString(realWord)
+		return err
+	}
+
+	isReserved := q.IsReserved(realWord)
+	if isReserved {
+		if err := buf.WriteByte(q.Prefix); err != nil {
+			return err
+		}
+	}
+	if _, err := buf.WriteString(realWord); err != nil {
+		return err
+	}
+	if isReserved {
+		return buf.WriteByte(q.Suffix)
+	}
+
+	return nil
+}
+
+// QuoteTo quotes the table or column names. i.e. if the quotes are [ and ]
+//   name -> [name]
+//   `name` -> [name]
+//   [name] -> [name]
+//   schema.name -> [schema].[name]
+//   `schema`.`name` -> [schema].[name]
+//   `schema`.name -> [schema].[name]
+//   schema.`name` -> [schema].[name]
+//   [schema].name -> [schema].[name]
+//   schema.[name] -> [schema].[name]
+//   name AS a -> [name] AS a
+//   schema.name AS a -> [schema].[name] AS a
+func (q Quoter) QuoteTo(buf *strings.Builder, value string) error {
+	var i int
+	for i < len(value) {
+		start := findStart(value, i)
+		if start > i {
+			if _, err := buf.WriteString(value[i:start]); err != nil {
+				return err
+			}
+		}
+		if start == len(value) {
+			return nil
+		}
+
+		var nextEnd = findWord(value, start)
+		if err := q.quoteWordTo(buf, value[start:nextEnd]); err != nil {
+			return err
+		}
+		i = nextEnd
+	}
+	return nil
+}
+
+// Strings quotes a slice of strings
+func (q Quoter) Strings(s []string) []string {
+	var res = make([]string, 0, len(s))
+	for _, a := range s {
+		res = append(res, q.Quote(a))
+	}
+	return res
+}
+
+// Replace replaces the common quote mark (`) with this quoter's quotes in the SQL
+func (q Quoter) Replace(sql string) string {
+	if q.IsEmpty() {
+		return sql
+	}
+
+	var buf strings.Builder
+	buf.Grow(len(sql))
+
+	var beginSingleQuote bool
+	for i := 0; i < len(sql); i++ {
+		if !beginSingleQuote && sql[i] == CommanQuoteMark {
+			var j = i + 1
+			for ; j < len(sql); j++ {
+				if sql[j] == CommanQuoteMark {
+					break
+				}
+			}
+			word := sql[i+1 : j]
+			isReserved := q.IsReserved(word)
+			if isReserved {
+				buf.WriteByte(q.Prefix)
+			}
+			buf.WriteString(word)
+			if isReserved {
+				buf.WriteByte(q.Suffix)
+			}
+			i = j
+		} else {
+			if sql[i] == '\'' {
+				beginSingleQuote = !beginSingleQuote
+			}
+			buf.WriteByte(sql[i])
+		}
+	}
+	return buf.String()
+}
diff --git a/schemas/quote_test.go b/schemas/quote_test.go
new file mode 100644
index 0000000..708b450
--- /dev/null
+++ b/schemas/quote_test.go
@@ -0,0 +1,181 @@
+// Copyright 2019 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
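+
+// A minimal usage sketch of Quoter; QuoteTo's doc comment in quote.go lists
+// the full rewrite rules, and the tests below cover the same behavior:
+//
+//	q := schemas.Quoter{Prefix: '[', Suffix: ']', IsReserved: schemas.AlwaysReserve}
+//	q.Quote("schema.name")         // "[schema].[name]"
+//	q.Trim("[schema].[name]")      // "schema.name"
+//	q.Replace("SELECT `a` FROM b") // "SELECT [a] FROM b"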
+ +package schemas + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAlwaysQuoteTo(t *testing.T) { + var ( + quoter = Quoter{'[', ']', AlwaysReserve} + kases = []struct { + expected string + value string + }{ + {"[mytable]", "mytable"}, + {"[mytable]", "`mytable`"}, + {"[mytable]", `[mytable]`}, + {`["mytable"]`, `"mytable"`}, + {"[myschema].[mytable]", "myschema.mytable"}, + {"[myschema].[mytable]", "`myschema`.mytable"}, + {"[myschema].[mytable]", "myschema.`mytable`"}, + {"[myschema].[mytable]", "`myschema`.`mytable`"}, + {"[myschema].[mytable]", `[myschema].mytable`}, + {"[myschema].[mytable]", `myschema.[mytable]`}, + {"[myschema].[mytable]", `[myschema].[mytable]`}, + {`["myschema].[mytable"]`, `"myschema.mytable"`}, + {"[message_user] AS [sender]", "`message_user` AS `sender`"}, + {"[myschema].[mytable] AS [table]", "myschema.mytable AS table"}, + {" [mytable]", " mytable"}, + {" [mytable]", " mytable"}, + {"[mytable] ", "mytable "}, + {"[mytable] ", "mytable "}, + {" [mytable] ", " mytable "}, + {" [mytable] ", " mytable "}, + } + ) + + for _, v := range kases { + t.Run(v.value, func(t *testing.T) { + buf := &strings.Builder{} + quoter.QuoteTo(buf, v.value) + assert.EqualValues(t, v.expected, buf.String()) + }) + } +} + +func TestReversedQuoteTo(t *testing.T) { + var ( + quoter = Quoter{'[', ']', func(s string) bool { + if s == "mytable" { + return true + } + return false + }} + kases = []struct { + expected string + value string + }{ + {"[mytable]", "mytable"}, + {"[mytable]", "`mytable`"}, + {"[mytable]", `[mytable]`}, + {`"mytable"`, `"mytable"`}, + {"myschema.[mytable]", "myschema.mytable"}, + {"myschema.[mytable]", "`myschema`.mytable"}, + {"myschema.[mytable]", "myschema.`mytable`"}, + {"myschema.[mytable]", "`myschema`.`mytable`"}, + {"myschema.[mytable]", `[myschema].mytable`}, + {"myschema.[mytable]", `myschema.[mytable]`}, + {"myschema.[mytable]", `[myschema].[mytable]`}, + {`"myschema.mytable"`, `"myschema.mytable"`}, + {"message_user AS sender", "`message_user` AS `sender`"}, + {"myschema.[mytable] AS table", "myschema.mytable AS table"}, + } + ) + + for _, v := range kases { + t.Run(v.value, func(t *testing.T) { + buf := &strings.Builder{} + quoter.QuoteTo(buf, v.value) + assert.EqualValues(t, v.expected, buf.String()) + }) + } +} + +func TestNoQuoteTo(t *testing.T) { + var ( + quoter = Quoter{'[', ']', AlwaysNoReserve} + kases = []struct { + expected string + value string + }{ + {"mytable", "mytable"}, + {"mytable", "`mytable`"}, + {"mytable", `[mytable]`}, + {`"mytable"`, `"mytable"`}, + {"myschema.mytable", "myschema.mytable"}, + {"myschema.mytable", "`myschema`.mytable"}, + {"myschema.mytable", "myschema.`mytable`"}, + {"myschema.mytable", "`myschema`.`mytable`"}, + {"myschema.mytable", `[myschema].mytable`}, + {"myschema.mytable", `myschema.[mytable]`}, + {"myschema.mytable", `[myschema].[mytable]`}, + {`"myschema.mytable"`, `"myschema.mytable"`}, + {"message_user AS sender", "`message_user` AS `sender`"}, + {"myschema.mytable AS table", "myschema.mytable AS table"}, + } + ) + + for _, v := range kases { + t.Run(v.value, func(t *testing.T) { + buf := &strings.Builder{} + quoter.QuoteTo(buf, v.value) + assert.EqualValues(t, v.expected, buf.String()) + }) + } +} + +func TestJoin(t *testing.T) { + cols := []string{"f1", "f2", "f3"} + quoter := Quoter{'[', ']', AlwaysReserve} + + assert.EqualValues(t, "[a],[b]", quoter.Join([]string{"a", " b"}, ",")) + + assert.EqualValues(t, "[f1], [f2], [f3]", quoter.Join(cols, ", ")) + + 
quoter.IsReserved = AlwaysNoReserve
+	assert.EqualValues(t, "f1, f2, f3", quoter.Join(cols, ", "))
+}
+
+func TestStrings(t *testing.T) {
+	cols := []string{"f1", "f2", "t3.f3"}
+	quoter := Quoter{'[', ']', AlwaysReserve}
+
+	quotedCols := quoter.Strings(cols)
+	assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]"}, quotedCols)
+}
+
+func TestTrim(t *testing.T) {
+	var kases = map[string]string{
+		"[table_name]":          "table_name",
+		"[schema].[table_name]": "schema.table_name",
+	}
+
+	for src, dst := range kases {
+		assert.EqualValues(t, src, CommonQuoter.Trim(src))
+		assert.EqualValues(t, dst, Quoter{'[', ']', AlwaysReserve}.Trim(src))
+	}
+}
+
+func TestReplace(t *testing.T) {
+	q := Quoter{'[', ']', AlwaysReserve}
+	var kases = []struct {
+		source   string
+		expected string
+	}{
+		{
+			"SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?",
+			"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
+		},
+		{
+			"SELECT 'abc```test```''', `a` FROM b",
+			"SELECT 'abc```test```''', [a] FROM b",
+		},
+		{
+			"UPDATE table SET `a` = ~ `a`, `b`='abc`'",
+			"UPDATE table SET [a] = ~ [a], [b]='abc`'",
+		},
+	}
+
+	for _, kase := range kases {
+		t.Run(kase.source, func(t *testing.T) {
+			assert.EqualValues(t, kase.expected, q.Replace(kase.source))
+		})
+	}
+}
diff --git a/schemas/table.go b/schemas/table.go
new file mode 100644
index 0000000..3859699
--- /dev/null
+++ b/schemas/table.go
@@ -0,0 +1,146 @@
+// Copyright 2019 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package schemas
+
+import (
+	"reflect"
+	"strings"
+)
+
+// Table represents a database table
+type Table struct {
+	Name          string
+	Type          reflect.Type
+	columnsSeq    []string
+	columnsMap    map[string][]*Column
+	columns       []*Column
+	Indexes       map[string]*Index
+	PrimaryKeys   []string
+	AutoIncrement string
+	Created       map[string]bool
+	Updated       string
+	Deleted       string
+	Version       string
+	StoreEngine   string
+	Charset       string
+	Comment       string
+}
+
+// NewEmptyTable creates an empty table
+func NewEmptyTable() *Table {
+	return NewTable("", nil)
+}
+
+// NewTable creates a new Table object
+func NewTable(name string, t reflect.Type) *Table {
+	return &Table{Name: name, Type: t,
+		columnsSeq:  make([]string, 0),
+		columns:     make([]*Column, 0),
+		columnsMap:  make(map[string][]*Column),
+		Indexes:     make(map[string]*Index),
+		Created:     make(map[string]bool),
+		PrimaryKeys: make([]string, 0),
+	}
+}
+
+// Columns returns the table's columns
+func (table *Table) Columns() []*Column {
+	return table.columns
+}
+
+// ColumnsSeq returns the column names in declaration order
+func (table *Table) ColumnsSeq() []string {
+	return table.columnsSeq
+}
+
+func (table *Table) columnsByName(name string) []*Column {
+	for k, cols := range table.columnsMap {
+		if strings.EqualFold(k, name) {
+			return cols
+		}
+	}
+	return nil
+}
+
+// GetColumn returns the column with the given name, or nil
+func (table *Table) GetColumn(name string) *Column {
+	cols := table.columnsByName(name)
+	if cols != nil {
+		return cols[0]
+	}
+
+	return nil
+}
+
+// GetColumnIdx returns the idx-th column with the given name, or nil
+func (table *Table) GetColumnIdx(name string, idx int) *Column {
+	cols := table.columnsByName(name)
+	if cols != nil && idx < len(cols) {
+		return cols[idx]
+	}
+
+	return nil
+}
+
+// PKColumns returns all primary key columns
+func (table *Table) PKColumns() []*Column {
+	columns := make([]*Column, len(table.PrimaryKeys))
+	for i, name := range table.PrimaryKeys {
+		columns[i] = table.GetColumn(name)
+	}
+	return columns
+}
+
+// ColumnType returns the Go type of the struct field that maps to name
+func (table *Table) ColumnType(name string) reflect.Type {
+	t, _ :=
 table.Type.FieldByName(name)
+	return t.Type
+}
+
+// AutoIncrColumn returns the auto-increment column
+func (table *Table) AutoIncrColumn() *Column {
+	return table.GetColumn(table.AutoIncrement)
+}
+
+// VersionColumn returns the version column
+func (table *Table) VersionColumn() *Column {
+	return table.GetColumn(table.Version)
+}
+
+// UpdatedColumn returns the updated-time column
+func (table *Table) UpdatedColumn() *Column {
+	return table.GetColumn(table.Updated)
+}
+
+// DeletedColumn returns the soft-delete flag column
+func (table *Table) DeletedColumn() *Column {
+	return table.GetColumn(table.Deleted)
+}
+
+// AddColumn adds a column to the table
+func (table *Table) AddColumn(col *Column) {
+	table.columnsSeq = append(table.columnsSeq, col.Name)
+	table.columns = append(table.columns, col)
+	colName := strings.ToLower(col.Name)
+	if c, ok := table.columnsMap[colName]; ok {
+		table.columnsMap[colName] = append(c, col)
+	} else {
+		table.columnsMap[colName] = []*Column{col}
+	}
+
+	if col.IsPrimaryKey {
+		table.PrimaryKeys = append(table.PrimaryKeys, col.Name)
+	}
+	if col.IsAutoIncrement {
+		table.AutoIncrement = col.Name
+	}
+	if col.IsCreated {
+		table.Created[col.Name] = true
+	}
+	if col.IsUpdated {
+		table.Updated = col.Name
+	}
+	if col.IsDeleted {
+		table.Deleted = col.Name
+	}
+	if col.IsVersion {
+		table.Version = col.Name
+	}
+}
+
+// AddIndex adds an index or a unique constraint to the table
+func (table *Table) AddIndex(index *Index) {
+	table.Indexes[index.Name] = index
+}
diff --git a/schemas/table_test.go b/schemas/table_test.go
new file mode 100644
index 0000000..9bf10e3
--- /dev/null
+++ b/schemas/table_test.go
@@ -0,0 +1,111 @@
+// Copyright 2019 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package schemas
+
+import (
+	"strings"
+	"testing"
+)
+
+var testsGetColumn = []struct {
+	name string
+	idx  int
+}{
+	{"Id", 0},
+	{"Deleted", 0},
+	{"Caption", 0},
+	{"Code_1", 0},
+	{"Code_2", 0},
+	{"Code_3", 0},
+	{"Parent_Id", 0},
+	{"Latitude", 0},
+	{"Longitude", 0},
+}
+
+var table *Table
+
+func init() {
+	table = NewEmptyTable()
+
+	var name string
+
+	for i := 0; i < len(testsGetColumn); i++ {
+		// as in Table.AddColumn func
+		name = strings.ToLower(testsGetColumn[i].name)
+
+		table.columnsMap[name] = append(table.columnsMap[name], &Column{})
+	}
+}
+
+func TestGetColumn(t *testing.T) {
+	for _, test := range testsGetColumn {
+		if table.GetColumn(test.name) == nil {
+			t.Error("Column not found!")
+		}
+	}
+}
+
+func TestGetColumnIdx(t *testing.T) {
+	for _, test := range testsGetColumn {
+		if table.GetColumnIdx(test.name, test.idx) == nil {
+			t.Errorf("Column %s with idx %d not found!", test.name, test.idx)
+		}
+	}
+}
+
+func BenchmarkGetColumnWithToLower(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		for _, test := range testsGetColumn {
+			if _, ok := table.columnsMap[strings.ToLower(test.name)]; !ok {
+				b.Errorf("Column not found:%s", test.name)
+			}
+		}
+	}
+}
+
+func BenchmarkGetColumnIdxWithToLower(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		for _, test := range testsGetColumn {
+			if c, ok := table.columnsMap[strings.ToLower(test.name)]; ok {
+				if test.idx < len(c) {
+					continue
+				} else {
+					b.Errorf("Bad idx in: %s, %d", test.name, test.idx)
+				}
+			} else {
+				b.Errorf("Column not found: %s, %d", test.name, test.idx)
+			}
+		}
+	}
+}
+
+func BenchmarkGetColumn(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		for _, test := range testsGetColumn {
+			if table.GetColumn(test.name) == nil {
+				b.Errorf("Column not found:%s", test.name)
+			}
+		}
+	}
+}
+
+func BenchmarkGetColumnIdx(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		for _, test := range testsGetColumn {
+ if table.GetColumnIdx(test.name, test.idx) == nil { + b.Errorf("Column not found:%s, %d", test.name, test.idx) + } + } + } +} diff --git a/schemas/type.go b/schemas/type.go new file mode 100644 index 0000000..89459a4 --- /dev/null +++ b/schemas/type.go @@ -0,0 +1,336 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +import ( + "reflect" + "sort" + "strings" + "time" +) + +type DBType string + +const ( + POSTGRES DBType = "postgres" + SQLITE DBType = "sqlite3" + MYSQL DBType = "mysql" + MSSQL DBType = "mssql" + ORACLE DBType = "oracle" +) + +// SQLType represents SQL types +type SQLType struct { + Name string + DefaultLength int + DefaultLength2 int +} + +const ( + UNKNOW_TYPE = iota + TEXT_TYPE + BLOB_TYPE + TIME_TYPE + NUMERIC_TYPE + ARRAY_TYPE +) + +func (s *SQLType) IsType(st int) bool { + if t, ok := SqlTypes[s.Name]; ok && t == st { + return true + } + return false +} + +func (s *SQLType) IsText() bool { + return s.IsType(TEXT_TYPE) +} + +func (s *SQLType) IsBlob() bool { + return s.IsType(BLOB_TYPE) +} + +func (s *SQLType) IsTime() bool { + return s.IsType(TIME_TYPE) +} + +func (s *SQLType) IsNumeric() bool { + return s.IsType(NUMERIC_TYPE) +} + +func (s *SQLType) IsArray() bool { + return s.IsType(ARRAY_TYPE) +} + +func (s *SQLType) IsJson() bool { + return s.Name == Json || s.Name == Jsonb +} + +var ( + Bit = "BIT" + TinyInt = "TINYINT" + SmallInt = "SMALLINT" + MediumInt = "MEDIUMINT" + Int = "INT" + Integer = "INTEGER" + BigInt = "BIGINT" + + Enum = "ENUM" + Set = "SET" + + Char = "CHAR" + Varchar = "VARCHAR" + NChar = "NCHAR" + NVarchar = "NVARCHAR" + TinyText = "TINYTEXT" + Text = "TEXT" + NText = "NTEXT" + Clob = "CLOB" + MediumText = "MEDIUMTEXT" + LongText = "LONGTEXT" + Uuid = "UUID" + UniqueIdentifier = "UNIQUEIDENTIFIER" + SysName = "SYSNAME" + + Date = "DATE" + DateTime = "DATETIME" + SmallDateTime = "SMALLDATETIME" + Time = "TIME" + TimeStamp = "TIMESTAMP" + TimeStampz = "TIMESTAMPZ" + Year = "YEAR" + + Decimal = "DECIMAL" + Numeric = "NUMERIC" + Money = "MONEY" + SmallMoney = "SMALLMONEY" + + Real = "REAL" + Float = "FLOAT" + Double = "DOUBLE" + + Binary = "BINARY" + VarBinary = "VARBINARY" + TinyBlob = "TINYBLOB" + Blob = "BLOB" + MediumBlob = "MEDIUMBLOB" + LongBlob = "LONGBLOB" + Bytea = "BYTEA" + + Bool = "BOOL" + Boolean = "BOOLEAN" + + Serial = "SERIAL" + BigSerial = "BIGSERIAL" + + Json = "JSON" + Jsonb = "JSONB" + + Array = "ARRAY" + + SqlTypes = map[string]int{ + Bit: NUMERIC_TYPE, + TinyInt: NUMERIC_TYPE, + SmallInt: NUMERIC_TYPE, + MediumInt: NUMERIC_TYPE, + Int: NUMERIC_TYPE, + Integer: NUMERIC_TYPE, + BigInt: NUMERIC_TYPE, + + Enum: TEXT_TYPE, + Set: TEXT_TYPE, + Json: TEXT_TYPE, + Jsonb: TEXT_TYPE, + + Char: TEXT_TYPE, + NChar: TEXT_TYPE, + Varchar: TEXT_TYPE, + NVarchar: TEXT_TYPE, + TinyText: TEXT_TYPE, + Text: TEXT_TYPE, + NText: TEXT_TYPE, + MediumText: TEXT_TYPE, + LongText: TEXT_TYPE, + Uuid: TEXT_TYPE, + Clob: TEXT_TYPE, + SysName: TEXT_TYPE, + + Date: TIME_TYPE, + DateTime: TIME_TYPE, + Time: TIME_TYPE, + TimeStamp: TIME_TYPE, + TimeStampz: TIME_TYPE, + SmallDateTime: TIME_TYPE, + Year: TIME_TYPE, + + Decimal: NUMERIC_TYPE, + Numeric: NUMERIC_TYPE, + Real: NUMERIC_TYPE, + Float: NUMERIC_TYPE, + Double: NUMERIC_TYPE, + Money: NUMERIC_TYPE, + SmallMoney: NUMERIC_TYPE, + + Binary: BLOB_TYPE, + VarBinary: BLOB_TYPE, + + TinyBlob: BLOB_TYPE, + Blob: BLOB_TYPE, + MediumBlob: BLOB_TYPE, + LongBlob: BLOB_TYPE, + 
Bytea:            BLOB_TYPE,
+		UniqueIdentifier: BLOB_TYPE,
+
+		Bool: NUMERIC_TYPE,
+
+		Serial:    NUMERIC_TYPE,
+		BigSerial: NUMERIC_TYPE,
+
+		Array: ARRAY_TYPE,
+	}
+
+	intTypes  = sort.StringSlice{"*int", "*int16", "*int32", "*int8"}
+	uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"}
+)
+
+// !nashtsai! treat following var as internal const values, these are used for reflect.TypeOf comparison
+var (
+	c_EMPTY_STRING       string
+	c_BOOL_DEFAULT       bool
+	c_BYTE_DEFAULT       byte
+	c_COMPLEX64_DEFAULT  complex64
+	c_COMPLEX128_DEFAULT complex128
+	c_FLOAT32_DEFAULT    float32
+	c_FLOAT64_DEFAULT    float64
+	c_INT64_DEFAULT      int64
+	c_UINT64_DEFAULT     uint64
+	c_INT32_DEFAULT      int32
+	c_UINT32_DEFAULT     uint32
+	c_INT16_DEFAULT      int16
+	c_UINT16_DEFAULT     uint16
+	c_INT8_DEFAULT       int8
+	c_UINT8_DEFAULT      uint8
+	c_INT_DEFAULT        int
+	c_UINT_DEFAULT       uint
+	c_TIME_DEFAULT       time.Time
+)
+
+var (
+	IntType   = reflect.TypeOf(c_INT_DEFAULT)
+	Int8Type  = reflect.TypeOf(c_INT8_DEFAULT)
+	Int16Type = reflect.TypeOf(c_INT16_DEFAULT)
+	Int32Type = reflect.TypeOf(c_INT32_DEFAULT)
+	Int64Type = reflect.TypeOf(c_INT64_DEFAULT)
+
+	UintType   = reflect.TypeOf(c_UINT_DEFAULT)
+	Uint8Type  = reflect.TypeOf(c_UINT8_DEFAULT)
+	Uint16Type = reflect.TypeOf(c_UINT16_DEFAULT)
+	Uint32Type = reflect.TypeOf(c_UINT32_DEFAULT)
+	Uint64Type = reflect.TypeOf(c_UINT64_DEFAULT)
+
+	Float32Type = reflect.TypeOf(c_FLOAT32_DEFAULT)
+	Float64Type = reflect.TypeOf(c_FLOAT64_DEFAULT)
+
+	Complex64Type  = reflect.TypeOf(c_COMPLEX64_DEFAULT)
+	Complex128Type = reflect.TypeOf(c_COMPLEX128_DEFAULT)
+
+	StringType = reflect.TypeOf(c_EMPTY_STRING)
+	BoolType   = reflect.TypeOf(c_BOOL_DEFAULT)
+	ByteType   = reflect.TypeOf(c_BYTE_DEFAULT)
+	BytesType  = reflect.SliceOf(ByteType)
+
+	TimeType = reflect.TypeOf(c_TIME_DEFAULT)
+)
+
+var (
+	PtrIntType   = reflect.PtrTo(IntType)
+	PtrInt8Type  = reflect.PtrTo(Int8Type)
+	PtrInt16Type = reflect.PtrTo(Int16Type)
+	PtrInt32Type = reflect.PtrTo(Int32Type)
+	PtrInt64Type = reflect.PtrTo(Int64Type)
+
+	PtrUintType   = reflect.PtrTo(UintType)
+	PtrUint8Type  = reflect.PtrTo(Uint8Type)
+	PtrUint16Type = reflect.PtrTo(Uint16Type)
+	PtrUint32Type = reflect.PtrTo(Uint32Type)
+	PtrUint64Type = reflect.PtrTo(Uint64Type)
+
+	PtrFloat32Type = reflect.PtrTo(Float32Type)
+	PtrFloat64Type = reflect.PtrTo(Float64Type)
+
+	PtrComplex64Type  = reflect.PtrTo(Complex64Type)
+	PtrComplex128Type = reflect.PtrTo(Complex128Type)
+
+	PtrStringType = reflect.PtrTo(StringType)
+	PtrBoolType   = reflect.PtrTo(BoolType)
+	PtrByteType   = reflect.PtrTo(ByteType)
+
+	PtrTimeType = reflect.PtrTo(TimeType)
+)
+
+// Type2SQLType generates an SQLType according to the given Go type
+func Type2SQLType(t reflect.Type) (st SQLType) {
+	switch k := t.Kind(); k {
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
+		st = SQLType{Int, 0, 0}
+	case reflect.Int64, reflect.Uint64:
+		st = SQLType{BigInt, 0, 0}
+	case reflect.Float32:
+		st = SQLType{Float, 0, 0}
+	case reflect.Float64:
+		st = SQLType{Double, 0, 0}
+	case reflect.Complex64, reflect.Complex128:
+		st = SQLType{Varchar, 64, 0}
+	case reflect.Array, reflect.Slice, reflect.Map:
+		if t.Elem() == reflect.TypeOf(c_BYTE_DEFAULT) {
+			st = SQLType{Blob, 0, 0}
+		} else {
+			st = SQLType{Text, 0, 0}
+		}
+	case reflect.Bool:
+		st = SQLType{Bool, 0, 0}
+	case reflect.String:
+		st = SQLType{Varchar, 255, 0}
+	case reflect.Struct:
+		if t.ConvertibleTo(TimeType) {
+			st = SQLType{DateTime, 0, 0}
+		} else {
+			// TODO need to handle association struct
+			st = SQLType{Text, 0, 0}
+		}
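+	// a pointer field maps to the SQL type of the value it points to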
+	case reflect.Ptr:
+		st = Type2SQLType(t.Elem())
+	default:
+		st = SQLType{Text, 0, 0}
+	}
+	return
+}
+
+// SQLType2Type returns the Go type that a default SQL type maps to
+func SQLType2Type(st SQLType) reflect.Type {
+	name := strings.ToUpper(st.Name)
+	switch name {
+	case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial:
+		return reflect.TypeOf(1)
+	case BigInt, BigSerial:
+		return reflect.TypeOf(int64(1))
+	case Float, Real:
+		return reflect.TypeOf(float32(1))
+	case Double:
+		return reflect.TypeOf(float64(1))
+	case Char, NChar, Varchar, NVarchar, TinyText, Text, NText, MediumText, LongText, Enum, Set, Uuid, Clob, SysName:
+		return reflect.TypeOf("")
+	case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary, UniqueIdentifier:
+		return reflect.TypeOf([]byte{})
+	case Bool:
+		return reflect.TypeOf(true)
+	case DateTime, Date, Time, TimeStamp, TimeStampz, SmallDateTime, Year:
+		return reflect.TypeOf(c_TIME_DEFAULT)
+	case Decimal, Numeric, Money, SmallMoney:
+		return reflect.TypeOf("")
+	default:
+		return reflect.TypeOf("")
+	}
+}
diff --git a/session.go b/session.go
index c9679c2..5f73d04 100644
--- a/session.go
+++ b/session.go
@@ -6,37 +6,47 @@ package xorm

 import (
 	"context"
+	"crypto/rand"
+	"crypto/sha256"
 	"database/sql"
+	"encoding/hex"
 	"errors"
 	"fmt"
 	"hash/crc32"
+	"io"
 	"reflect"
 	"strings"
 	"time"

-	"github.com/xormplus/core"
+	"github.com/xormplus/xorm/contexts"
+	"github.com/xormplus/xorm/convert"
+	"github.com/xormplus/xorm/core"
+	"github.com/xormplus/xorm/internal/json"
+	"github.com/xormplus/xorm/internal/statements"
+	"github.com/xormplus/xorm/log"
+	"github.com/xormplus/xorm/schemas"
 )

-type sessionType int
+type sessionType bool

 const (
-	engineSession sessionType = iota
-	groupSession
+	engineSession sessionType = false
+	groupSession  sessionType = true
 )

 // Session keep a pointer to sql.DB and provides all execution of all
 // kind of database operations.
 type Session struct {
-	db                     *core.DB
 	engine                 *Engine
 	tx                     *core.Tx
-	statement              Statement
+	statement              *statements.Statement
 	currentTransaction     *Transaction
 	isAutoCommit           bool
 	isCommitedOrRollbacked bool
 	isSqlFunc              bool
 	isAutoClose            bool
-
+	isClosed               bool
+	prepareStmt            bool
 	// Automatically reset the statement after operations that execute a SQL
 	// query such as Count(), Find(), Get(), ...
 	autoResetStatement bool
@@ -47,16 +57,12 @@ type Session struct {
 	afterDeleteBeans map[interface{}]*[]func(interface{})
 	// --

-	beforeClosures []func(interface{})
-	afterClosures  []func(interface{})
-
+	beforeClosures  []func(interface{})
+	afterClosures   []func(interface{})
 	afterProcessors []executedProcessor

-	prepareStmt bool
-	stmtCache   map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr))
+	stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr))

-	// !evalphobia! stored the last executed query on this session
-	//beforeSQLExec func(string, ...interface{})
 	lastSQL     string
 	lastSQLArgs []interface{}
 	showSQL     bool
@@ -69,72 +75,105 @@ type Session struct {
 	err error
 }

-// Clone copy all the session's content and return a new session.
-func (session *Session) Clone() *Session {
-	var sess = *session
-	return &sess
+func newSessionID() string {
+	hash := sha256.New()
+	_, err := io.CopyN(hash, rand.Reader, 50)
+	if err != nil {
+		return "????????????????????"
+	}
+	md := hash.Sum(nil)
+	mdStr := hex.EncodeToString(md)
+	return mdStr[0:20]
 }

-// Init reset the session as the init status.
-func (session *Session) Init() { - session.statement.Init() - session.statement.Engine = session.engine - session.showSQL = session.engine.showSQL - session.isAutoCommit = true - session.isCommitedOrRollbacked = false - session.isAutoClose = false - session.isSqlFunc = false - session.autoResetStatement = true - session.prepareStmt = false - - // !nashtsai! is lazy init better? - session.afterInsertBeans = make(map[interface{}]*[]func(interface{}), 0) - session.afterUpdateBeans = make(map[interface{}]*[]func(interface{}), 0) - session.afterDeleteBeans = make(map[interface{}]*[]func(interface{}), 0) - session.beforeClosures = make([]func(interface{}), 0) - session.afterClosures = make([]func(interface{}), 0) - session.stmtCache = make(map[uint32]*core.Stmt) - - session.afterProcessors = make([]executedProcessor, 0) - - session.lastSQL = "" - session.lastSQLArgs = []interface{}{} +func newSession(engine *Engine) *Session { + var ctx context.Context + if engine.logSessionID { + ctx = context.WithValue(engine.defaultContext, log.SessionIDKey, newSessionID()) + } else { + ctx = engine.defaultContext + } - session.ctx = session.engine.defaultContext + return &Session{ + ctx: ctx, + engine: engine, + tx: nil, + statement: statements.NewStatement( + engine.dialect, + engine.tagParser, + engine.DatabaseTZ, + ), + isClosed: false, + isAutoCommit: true, + isCommitedOrRollbacked: false, + isAutoClose: false, + isSqlFunc: false, + autoResetStatement: true, + prepareStmt: false, + + afterInsertBeans: make(map[interface{}]*[]func(interface{}), 0), + afterUpdateBeans: make(map[interface{}]*[]func(interface{}), 0), + afterDeleteBeans: make(map[interface{}]*[]func(interface{}), 0), + beforeClosures: make([]func(interface{}), 0), + afterClosures: make([]func(interface{}), 0), + afterProcessors: make([]executedProcessor, 0), + stmtCache: make(map[uint32]*core.Stmt), + + lastSQL: "", + lastSQLArgs: make([]interface{}, 0), + + sessionType: engineSession, + } } // Close release the connection from pool -func (session *Session) Close() { +func (session *Session) Close() error { for _, v := range session.stmtCache { - v.Close() + if err := v.Close(); err != nil { + return err + } } - if session.db != nil { + if !session.isClosed { // When Close be called, if session is a transaction and do not call // Commit or Rollback, then call Rollback. 
if session.tx != nil && !session.isCommitedOrRollbacked { - session.Rollback() + if err := session.Rollback(); err != nil { + return err + } } session.tx = nil session.stmtCache = nil - session.db = nil + session.isClosed = true } + return nil +} + +func (session *Session) db() *core.DB { + return session.engine.db +} + +func (session *Session) getQueryer() core.Queryer { + if session.tx != nil { + return session.tx + } + return session.db() } // ContextCache enable context cache or not -func (session *Session) ContextCache(context ContextCache) *Session { - session.statement.context = context +func (session *Session) ContextCache(context contexts.ContextCache) *Session { + session.statement.SetContextCache(context) return session } // IsClosed returns if session is closed func (session *Session) IsClosed() bool { - return session.db == nil + return session.isClosed } func (session *Session) resetStatement() { if session.autoResetStatement { - session.statement.Init() + session.statement.Reset() } session.isSqlFunc = false } @@ -163,7 +202,9 @@ func (session *Session) After(closures func(interface{})) *Session { // Table can input a string or pointer to struct for special a table to operate. func (session *Session) Table(tableNameOrBean interface{}) *Session { - session.statement.Table(tableNameOrBean) + if err := session.statement.SetTable(tableNameOrBean); err != nil { + session.statement.LastError = err + } return session } @@ -187,7 +228,7 @@ func (session *Session) ForUpdate() *Session { // NoAutoCondition disable generate SQL condition from beans func (session *Session) NoAutoCondition(no ...bool) *Session { - session.statement.NoAutoCondition(no...) + session.statement.SetNoAutoCondition(no...) return session } @@ -237,12 +278,12 @@ func (session *Session) Cascade(trueOrFalse ...bool) *Session { } // MustLogSQL means record SQL or not and don't follow engine's setting -func (session *Session) MustLogSQL(log ...bool) *Session { - if len(log) > 0 { - session.showSQL = log[0] - } else { - session.showSQL = true +func (session *Session) MustLogSQL(logs ...bool) *Session { + var showSQL = true + if len(logs) > 0 { + showSQL = logs[0] } + session.ctx = context.WithValue(session.ctx, log.SessionShowSQLKey, showSQL) return session } @@ -273,17 +314,7 @@ func (session *Session) Having(conditions string) *Session { // DB db return the wrapper of sql.DB func (session *Session) DB() *core.DB { - if session.db == nil { - session.db = session.engine.db - session.stmtCache = make(map[uint32]*core.Stmt, 0) - } - return session.db -} - -func cleanupProcessorsClosures(slices *[]func(interface{})) { - if len(*slices) > 0 { - *slices = make([]func(interface{}), 0) - } + return session.db() } func (session *Session) canCache() bool { @@ -293,7 +324,7 @@ func (session *Session) canCache() bool { !session.statement.UseCache || session.statement.IsForUpdate || session.tx != nil || - len(session.statement.selectStr) > 0 { + len(session.statement.SelectStr) > 0 { return false } return true @@ -314,8 +345,8 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, return } -func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) (*reflect.Value, error) { - var col *core.Column +func (session *Session) getField(dataStruct *reflect.Value, key string, table *schemas.Table, idx int) (*reflect.Value, error) { + var col *schemas.Column if col = table.GetColumnIdx(key, idx); col == nil { return nil, ErrFieldIsNotExist{key, table.Name} } @@ -336,8 
+367,8 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *c type Cell *interface{} func (session *Session) rows2Beans(rows *core.Rows, fields []string, - table *core.Table, newElemFunc func([]string) reflect.Value, - sliceValueSetFunc func(*reflect.Value, core.PK) error) error { + table *schemas.Table, newElemFunc func([]string) reflect.Value, + sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { for rows.Next() { var newValue = newElemFunc(fields) bean := newValue.Interface() @@ -377,59 +408,20 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa return nil, err } - if b, hasBeforeSet := bean.(BeforeSetProcessor); hasBeforeSet { - for ii, key := range fields { - b.BeforeSet(key, Cell(scanResults[ii].(*interface{}))) - } - } + executeBeforeSet(bean, fields, scanResults) + return scanResults, nil } -func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *core.Table) (core.PK, error) { +func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) { defer func() { - if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet { - for ii, key := range fields { - b.AfterSet(key, Cell(scanResults[ii].(*interface{}))) - } - } + executeAfterSet(bean, fields, scanResults) }() - // handle afterClosures - for _, closure := range session.afterClosures { - session.afterProcessors = append(session.afterProcessors, executedProcessor{ - fun: func(sess *Session, bean interface{}) error { - closure(bean) - return nil - }, - session: session, - bean: bean, - }) - } - - if a, has := bean.(AfterLoadProcessor); has { - session.afterProcessors = append(session.afterProcessors, executedProcessor{ - fun: func(sess *Session, bean interface{}) error { - a.AfterLoad() - return nil - }, - session: session, - bean: bean, - }) - } - - if a, has := bean.(AfterLoadSessionProcessor); has { - session.afterProcessors = append(session.afterProcessors, executedProcessor{ - fun: func(sess *Session, bean interface{}) error { - a.AfterLoad(sess) - return nil - }, - session: session, - bean: bean, - }) - } + buildAfterProcessors(session, bean) var tempMap = make(map[string]int) - var pk core.PK + var pk schemas.PK for ii, key := range fields { var idx int var ok bool @@ -444,7 +436,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b fieldValue, err := session.getField(dataStruct, key, table, idx) if err != nil { if !strings.Contains(err.Error(), "is not valid") { - session.engine.logger.Warn(err) + session.engine.logger.Warnf("%v", err) } continue } @@ -459,7 +451,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } if fieldValue.CanAddr() { - if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { + if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { if data, err := value2Bytes(&rawValue); err == nil { if err := structConvert.FromDB(data); err != nil { return nil, err @@ -471,12 +463,12 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } } - if _, ok := fieldValue.Interface().(core.Conversion); ok { + if _, ok := fieldValue.Interface().(convert.Conversion); ok { if data, err := value2Bytes(&rawValue); err == nil { if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { fieldValue.Set(reflect.New(fieldValue.Type().Elem())) } - 
fieldValue.Interface().(core.Conversion).FromDB(data) + fieldValue.Interface().(convert.Conversion).FromDB(data) } else { return nil, err } @@ -496,7 +488,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b var bs []byte if rawValueType.Kind() == reflect.String { bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(core.BytesType) { + } else if rawValueType.ConvertibleTo(schemas.BytesType) { bs = vv.Bytes() } else { return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) @@ -510,13 +502,13 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b continue } if fieldValue.CanAddr() { - err := DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) + err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) if err != nil { return nil, err } } else { x := reflect.New(fieldType) - err := DefaultJSONHandler.Unmarshal(bs, x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) if err != nil { return nil, err } @@ -533,20 +525,20 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b var bs []byte if rawValueType.Kind() == reflect.String { bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(core.BytesType) { + } else if rawValueType.ConvertibleTo(schemas.BytesType) { bs = vv.Bytes() } hasAssigned = true if len(bs) > 0 { if fieldValue.CanAddr() { - err := DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) + err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) if err != nil { return nil, err } } else { x := reflect.New(fieldType) - err := DefaultJSONHandler.Unmarshal(bs, x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) if err != nil { return nil, err } @@ -562,7 +554,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true if col.SQLType.IsText() { x := reflect.New(fieldType) - err := DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) if err != nil { return nil, err } @@ -615,16 +607,16 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b fieldValue.SetUint(uint64(vv.Int())) } case reflect.Struct: - if fieldType.ConvertibleTo(core.TimeType) { + if fieldType.ConvertibleTo(schemas.TimeType) { dbTZ := session.engine.DatabaseTZ if col.TimeZone != nil { dbTZ = col.TimeZone } - if rawValueType == core.TimeType { + if rawValueType == schemas.TimeType { hasAssigned = true - t := vv.Convert(core.TimeType).Interface().(time.Time) + t := vv.Convert(schemas.TimeType).Interface().(time.Time) z, _ := t.Zone() // set new location if database don't save timezone or give an incorrect timezone @@ -636,8 +628,8 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b t = t.In(session.engine.TZLocation) fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } else if rawValueType == core.IntType || rawValueType == core.Int64Type || - rawValueType == core.Int32Type { + } else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type || + rawValueType == schemas.Int32Type { hasAssigned = true t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) @@ -647,7 +639,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true t, err := session.byte2Time(col, d) if err != nil { - session.engine.logger.Error("byte2Time error:", 
err.Error()) + session.engine.logger.Errorf("byte2Time error: %v", err) hasAssigned = false } else { fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) } @@ -656,7 +648,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true t, err := session.str2Time(col, d) if err != nil { - session.engine.logger.Error("byte2Time error:", err.Error()) + session.engine.logger.Errorf("str2Time error: %v", err) hasAssigned = false } else { fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) } @@ -669,7 +661,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b // !! Add support for structs implementing the sql.Scanner interface, such as sql.NullString hasAssigned = true if err := nulVal.Scan(vv.Interface()); err != nil { - session.engine.logger.Error("sql.Sanner error:", err.Error()) + session.engine.logger.Errorf("sql.Scanner error: %v", err) hasAssigned = false } } else if col.SQLType.IsJson() { @@ -677,7 +669,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true x := reflect.New(fieldType) if len([]byte(vv.String())) > 0 { - err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface()) + err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface()) if err != nil { return nil, err } @@ -687,7 +679,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true x := reflect.New(fieldType) if len(vv.Bytes()) > 0 { - err := DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) if err != nil { return nil, err } @@ -695,7 +687,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } } } else if session.statement.UseCascade { - table, err := session.engine.autoMapType(*fieldValue) + table, err := session.engine.tagParser.ParseWithCache(*fieldValue) if err != nil { return nil, err } @@ -704,13 +696,13 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b if len(table.PrimaryKeys) != 1 { return nil, errors.New("unsupported non or composited primary key cascade") } - var pk = make(core.PK, len(table.PrimaryKeys)) + var pk = make(schemas.PK, len(table.PrimaryKeys)) pk[0], err = asKind(vv, rawValueType) if err != nil { return nil, err } - if !isPKZero(pk) { + if !pk.IsZero() { // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // property to be fetched lazily @@ -730,110 +722,110 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b // !nashtsai!
TODO merge duplicated codes above switch fieldType { // following types case matching ptr's native type, therefore assign ptr directly - case core.PtrStringType: + case schemas.PtrStringType: if rawValueType.Kind() == reflect.String { x := vv.String() hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrBoolType: + case schemas.PtrBoolType: if rawValueType.Kind() == reflect.Bool { x := vv.Bool() hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrTimeType: - if rawValueType == core.PtrTimeType { + case schemas.PtrTimeType: + if rawValueType == schemas.PtrTimeType { hasAssigned = true var x = rawValue.Interface().(time.Time) fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrFloat64Type: + case schemas.PtrFloat64Type: if rawValueType.Kind() == reflect.Float64 { x := vv.Float() hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrUint64Type: + case schemas.PtrUint64Type: if rawValueType.Kind() == reflect.Int64 { var x = uint64(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrInt64Type: + case schemas.PtrInt64Type: if rawValueType.Kind() == reflect.Int64 { x := vv.Int() hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrFloat32Type: + case schemas.PtrFloat32Type: if rawValueType.Kind() == reflect.Float64 { var x = float32(vv.Float()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrIntType: + case schemas.PtrIntType: if rawValueType.Kind() == reflect.Int64 { var x = int(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrInt32Type: + case schemas.PtrInt32Type: if rawValueType.Kind() == reflect.Int64 { var x = int32(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrInt8Type: + case schemas.PtrInt8Type: if rawValueType.Kind() == reflect.Int64 { var x = int8(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrInt16Type: + case schemas.PtrInt16Type: if rawValueType.Kind() == reflect.Int64 { var x = int16(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrUintType: + case schemas.PtrUintType: if rawValueType.Kind() == reflect.Int64 { var x = uint(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrUint32Type: + case schemas.PtrUint32Type: if rawValueType.Kind() == reflect.Int64 { var x = uint32(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.Uint8Type: + case schemas.Uint8Type: if rawValueType.Kind() == reflect.Int64 { var x = uint8(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.Uint16Type: + case schemas.Uint16Type: if rawValueType.Kind() == reflect.Int64 { var x = uint16(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.Complex64Type: + case schemas.Complex64Type: var x complex64 if len([]byte(vv.String())) > 0 { - err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) + err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) if err != nil { return nil, err } fieldValue.Set(reflect.ValueOf(&x)) } hasAssigned = true - case core.Complex128Type: + case schemas.Complex128Type: var x complex128 if len([]byte(vv.String())) > 0 { - err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) + err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) if err != nil { return nil, err } @@ -866,7 +858,7 @@ func (session *Session) saveLastSQL(sql string, args ...interface{}) { } func (session *Session) 
logSQL(sqlStr string, sqlArgs ...interface{}) { - if session.showSQL && !session.engine.showExecTime { + if session.showSQL { if len(sqlArgs) > 0 { session.engine.logger.Infof("[SQL] %v %#v", sqlStr, sqlArgs) } else { @@ -882,7 +874,7 @@ func (session *Session) LastSQL() (string, []interface{}) { // Unscoped always disable struct tag "deleted" func (session *Session) Unscoped() *Session { - session.statement.Unscoped() + session.statement.SetUnscoped() return session } @@ -894,3 +886,19 @@ func (session *Session) incrVersionFieldValue(fieldValue *reflect.Value) { fieldValue.SetUint(fieldValue.Uint() + 1) } } + +// Context sets the context on this session +func (session *Session) Context(ctx context.Context) *Session { + session.ctx = ctx + return session +} + +// PingContext tests if the database is ok +func (session *Session) PingContext(ctx context.Context) error { + if session.isAutoClose { + defer session.Close() + } + + session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName()) + return session.DB().PingContext(ctx) +} diff --git a/session_cols.go b/session_cols.go index f4e81c1..8e12da2 100644 --- a/session_cols.go +++ b/session_cols.go @@ -9,10 +9,10 @@ import ( "strings" "time" - "github.com/xormplus/core" + "github.com/xormplus/xorm/schemas" ) -func setColumnInt(bean interface{}, col *core.Column, t int64) { +func setColumnInt(bean interface{}, col *schemas.Column, t int64) { v, err := col.ValueOf(bean) if err != nil { return @@ -27,7 +27,7 @@ func setColumnInt(bean interface{}, col *core.Column, t int64) { } } -func setColumnTime(bean interface{}, col *core.Column, t time.Time) { +func setColumnTime(bean interface{}, col *schemas.Column, t time.Time) { v, err := col.ValueOf(bean) if err != nil { return @@ -44,7 +44,7 @@ func setColumnTime(bean interface{}, col *core.Column, t time.Time) { } } -func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) { +func getFlagForColumn(m map[string]bool, col *schemas.Column) (val bool, has bool) { if len(m) == 0 { return false, false } @@ -77,14 +77,14 @@ func col2NewCols(columns ...string) []string { } // Incr provides a query string like "count = count + 1" -func (session *Session) Incr(column string, args ...interface{}) *Session { - session.statement.Incr(column, args...) +func (session *Session) Incr(column string, arg ...interface{}) *Session { + session.statement.Incr(column, arg...) return session } // Decr provides a query string like "count = count - 1" -func (session *Session) Decr(column string, args ...interface{}) *Session { - session.statement.Decr(column, args...) +func (session *Session) Decr(column string, arg ...interface{}) *Session { + session.statement.Decr(column, arg...) return session } diff --git a/session_cond.go b/session_cond.go index b2c51f5..beb4b85 100644 --- a/session_cond.go +++ b/session_cond.go @@ -6,14 +6,6 @@ package xorm import "github.com/xormplus/builder" -// Sql provides raw sql input parameter. When you have a complex SQL statement -// and cannot use Where, Id, In and etc. Methods to describe, you can use SQL. -// -// Deprecated: use SQL instead. -func (session *Session) Sql(query interface{}, args ...interface{}) *Session { - return session.SQL(query, args...) -} - // SQL provides raw sql input parameter. When you have a complex SQL statement // and cannot use Where, Id, In and etc. Methods to describe, you can use SQL.
func (session *Session) SQL(query interface{}, args ...interface{}) *Session { @@ -40,13 +32,6 @@ func (session *Session) Or(query interface{}, args ...interface{}) *Session { return session } -// Id provides converting id as a query condition -// -// Deprecated: use ID instead -func (session *Session) Id(id interface{}) *Session { - return session.ID(id) -} - // ID provides converting id as a query condition func (session *Session) ID(id interface{}) *Session { session.statement.ID(id) @@ -67,5 +52,5 @@ func (session *Session) NotIn(column string, args ...interface{}) *Session { // Conds returns session query conditions except auto bean conditions func (session *Session) Conds() builder.Cond { - return session.statement.cond + return session.statement.Conds() } diff --git a/session_context.go b/session_context.go deleted file mode 100644 index 915f056..0000000 --- a/session_context.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2019 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import "context" - -// Context sets the context on this session -func (session *Session) Context(ctx context.Context) *Session { - session.ctx = ctx - return session -} - -// PingContext test if database is ok -func (session *Session) PingContext(ctx context.Context) error { - if session.isAutoClose { - defer session.Close() - } - - session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName()) - return session.DB().PingContext(ctx) -} diff --git a/session_context_test.go b/session_context_test.go deleted file mode 100644 index 2784468..0000000 --- a/session_context_test.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2019 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package xorm - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestQueryContext(t *testing.T) { - type ContextQueryStruct struct { - Id int64 - Name string - } - - assert.NoError(t, prepareEngine()) - assertSync(t, new(ContextQueryStruct)) - - _, err := testEngine.Insert(&ContextQueryStruct{Name: "1"}) - assert.NoError(t, err) - - ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) - defer cancel() - - time.Sleep(time.Nanosecond) - - has, err := testEngine.Context(ctx).Exist(&ContextQueryStruct{Name: "1"}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "context deadline exceeded") - assert.False(t, has) -} diff --git a/session_convert.go b/session_convert.go index a97e64a..9bdc2ce 100644 --- a/session_convert.go +++ b/session_convert.go @@ -14,10 +14,14 @@ import ( "strings" "time" - "github.com/xormplus/core" + "github.com/xormplus/xorm/convert" + "github.com/xormplus/xorm/dialects" + "github.com/xormplus/xorm/internal/json" + "github.com/xormplus/xorm/internal/utils" + "github.com/xormplus/xorm/schemas" ) -func (session *Session) str2Time(col *core.Column, data string) (outTime time.Time, outErr error) { +func (session *Session) str2Time(col *schemas.Column, data string) (outTime time.Time, outErr error) { sdata := strings.TrimSpace(data) var x time.Time var err error @@ -27,7 +31,7 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti parseLoc = col.TimeZone } - if sdata == zeroTime0 || sdata == zeroTime1 { + if sdata == utils.ZeroTime0 || sdata == utils.ZeroTime1 { } else if !strings.ContainsAny(sdata, "- :") { // !nashtsai! has only found that mymysql driver is using this for time type column // time stamp sd, err := strconv.ParseInt(sdata, 10, 64) @@ -54,14 +58,14 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc) //session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) - } else if col.SQLType.Name == core.Time { + } else if col.SQLType.Name == schemas.Time { if strings.Contains(sdata, " ") { ssd := strings.Split(sdata, " ") sdata = ssd[1] } sdata = strings.TrimSpace(sdata) - if session.engine.dialect.DBType() == core.MYSQL && len(sdata) > 8 { + if session.engine.dialect.URI().DBType == schemas.MYSQL && len(sdata) > 8 { sdata = sdata[len(sdata)-8:] } @@ -80,7 +84,7 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti return } -func (session *Session) byte2Time(col *core.Column, data []byte) (outTime time.Time, outErr error) { +func (session *Session) byte2Time(col *schemas.Column, data []byte) (outTime time.Time, outErr error) { return session.str2Time(col, string(data)) } @@ -89,12 +93,12 @@ var ( ) // convert a db data([]byte) to a field value -func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, data []byte) error { - if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { +func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Value, data []byte) error { + if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { return structConvert.FromDB(data) } - if structConvert, ok := fieldValue.Interface().(core.Conversion); ok { + if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok { return structConvert.FromDB(data) } @@ -106,9 +110,9 @@ 
func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, case reflect.Complex64, reflect.Complex128: x := reflect.New(fieldType) if len(data) > 0 { - err := DefaultJSONHandler.Unmarshal(data, x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) return err } fieldValue.Set(x.Elem()) @@ -120,9 +124,9 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, if col.SQLType.IsText() { x := reflect.New(fieldType) if len(data) > 0 { - err := DefaultJSONHandler.Unmarshal(data, x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) return err } fieldValue.Set(x.Elem()) @@ -133,9 +137,9 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } else { x := reflect.New(fieldType) if len(data) > 0 { - err := DefaultJSONHandler.Unmarshal(data, x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) return err } fieldValue.Set(x.Elem()) @@ -157,8 +161,8 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, var x int64 var err error // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == core.Bit && - session.engine.dialect.DBType() == core.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API + if col.SQLType.Name == schemas.Bit && + session.engine.dialect.URI().DBType == schemas.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API if len(data) == 1 { x = int64(data[0]) } else { @@ -199,7 +203,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, return fmt.Errorf("sql.Scan(%v) failed: %s ", data, err.Error()) } } else { - if fieldType.ConvertibleTo(core.TimeType) { + if fieldType.ConvertibleTo(schemas.TimeType) { x, err := session.byte2Time(col, data) if err != nil { return err @@ -207,7 +211,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, v = x fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) } else if session.statement.UseCascade { - table, err := session.engine.autoMapType(*fieldValue) + table, err := session.engine.tagParser.ParseWithCache(*fieldValue) if err != nil { return err } @@ -217,14 +221,14 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, return errors.New("unsupported composited primary key cascade") } - var pk = make(core.PK, len(table.PrimaryKeys)) + var pk = make(schemas.PK, len(table.PrimaryKeys)) rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) pk[0], err = str2PK(string(data), rawValueType) if err != nil { return err } - if !isPKZero(pk) { + if !pk.IsZero() { // !nashtsai! 
TODO for hasOne relationship, it's preferred to use join query for eager fetch // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // property to be fetched lazily @@ -247,11 +251,11 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, //typeStr := fieldType.String() switch fieldType.Elem().Kind() { // case "*string": - case core.StringType.Kind(): + case schemas.StringType.Kind(): x := string(data) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*bool": - case core.BoolType.Kind(): + case schemas.BoolType.Kind(): d := string(data) v, err := strconv.ParseBool(d) if err != nil { @@ -259,36 +263,36 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } fieldValue.Set(reflect.ValueOf(&v).Convert(fieldType)) // case "*complex64": - case core.Complex64Type.Kind(): + case schemas.Complex64Type.Kind(): var x complex64 if len(data) > 0 { - err := DefaultJSONHandler.Unmarshal(data, &x) + err := json.DefaultJSONHandler.Unmarshal(data, &x) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) return err } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) } // case "*complex128": - case core.Complex128Type.Kind(): + case schemas.Complex128Type.Kind(): var x complex128 if len(data) > 0 { - err := DefaultJSONHandler.Unmarshal(data, &x) + err := json.DefaultJSONHandler.Unmarshal(data, &x) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) return err } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) } // case "*float64": - case core.Float64Type.Kind(): + case schemas.Float64Type.Kind(): x, err := strconv.ParseFloat(string(data), 64) if err != nil { return fmt.Errorf("arg %v as float64: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*float32": - case core.Float32Type.Kind(): + case schemas.Float32Type.Kind(): var x float32 x1, err := strconv.ParseFloat(string(data), 32) if err != nil { @@ -297,7 +301,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, x = float32(x1) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*uint64": - case core.Uint64Type.Kind(): + case schemas.Uint64Type.Kind(): var x uint64 x, err := strconv.ParseUint(string(data), 10, 64) if err != nil { @@ -305,7 +309,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*uint": - case core.UintType.Kind(): + case schemas.UintType.Kind(): var x uint x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { @@ -314,7 +318,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, x = uint(x1) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*uint32": - case core.Uint32Type.Kind(): + case schemas.Uint32Type.Kind(): var x uint32 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { @@ -323,7 +327,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, x = uint32(x1) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*uint8": - case core.Uint8Type.Kind(): + case schemas.Uint8Type.Kind(): var x uint8 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { @@ -332,7 +336,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, x = uint8(x1) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*uint16": 
- case core.Uint16Type.Kind(): + case schemas.Uint16Type.Kind(): var x uint16 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { @@ -341,12 +345,12 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, x = uint16(x1) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*int64": - case core.Int64Type.Kind(): + case schemas.Int64Type.Kind(): sdata := string(data) var x int64 var err error // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == core.Bit && + if col.SQLType.Name == schemas.Bit && strings.Contains(session.engine.DriverName(), "mysql") { if len(data) == 1 { x = int64(data[0]) @@ -365,13 +369,13 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*int": - case core.IntType.Kind(): + case schemas.IntType.Kind(): sdata := string(data) var x int var x1 int64 var err error // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == core.Bit && + if col.SQLType.Name == schemas.Bit && strings.Contains(session.engine.DriverName(), "mysql") { if len(data) == 1 { x = int(data[0]) @@ -393,14 +397,14 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*int32": - case core.Int32Type.Kind(): + case schemas.Int32Type.Kind(): sdata := string(data) var x int32 var x1 int64 var err error // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == core.Bit && - session.engine.dialect.DBType() == core.MYSQL { + if col.SQLType.Name == schemas.Bit && + session.engine.dialect.URI().DBType == schemas.MYSQL { if len(data) == 1 { x = int32(data[0]) } else { @@ -421,13 +425,13 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*int8": - case core.Int8Type.Kind(): + case schemas.Int8Type.Kind(): sdata := string(data) var x int8 var x1 int64 var err error // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == core.Bit && + if col.SQLType.Name == schemas.Bit && strings.Contains(session.engine.DriverName(), "mysql") { if len(data) == 1 { x = int8(data[0]) @@ -449,13 +453,13 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*int16": - case core.Int16Type.Kind(): + case schemas.Int16Type.Kind(): sdata := string(data) var x int16 var x1 int64 var err error // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == core.Bit && + if col.SQLType.Name == schemas.Bit && strings.Contains(session.engine.DriverName(), "mysql") { if len(data) == 1 { x = int16(data[0]) @@ -480,7 +484,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, case reflect.Struct: switch fieldType { // case "*.time.Time": - case core.PtrTimeType: + case schemas.PtrTimeType: x, err := session.byte2Time(col, data) if err != nil { return err @@ -490,7 +494,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, default: if session.statement.UseCascade { structInter := reflect.New(fieldType.Elem()) - table, err := session.engine.autoMapType(structInter.Elem()) + table, err := session.engine.tagParser.ParseWithCache(structInter.Elem()) if err != nil { return err } @@ -499,14 +503,14 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, return 
errors.New("unsupported composited primary key cascade") } - var pk = make(core.PK, len(table.PrimaryKeys)) + var pk = make(schemas.PK, len(table.PrimaryKeys)) rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) pk[0], err = str2PK(string(data), rawValueType) if err != nil { return err } - if !isPKZero(pk) { + if !pk.IsZero() { // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // property to be fetched lazily @@ -536,9 +540,9 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } // convert a field value of a struct to interface for put into db -func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Value) (interface{}, error) { +func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.Value) (interface{}, error) { if fieldValue.CanAddr() { - if fieldConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { + if fieldConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { data, err := fieldConvert.ToDB() if err != nil { return 0, err @@ -550,7 +554,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val } } - if fieldConvert, ok := fieldValue.Interface().(core.Conversion); ok { + if fieldConvert, ok := fieldValue.Interface().(convert.Conversion); ok { data, err := fieldConvert.ToDB() if err != nil { return 0, err @@ -567,7 +571,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val if fieldValue.IsNil() { return nil, nil } else if !fieldValue.IsValid() { - session.engine.logger.Warn("the field[", col.FieldName, "] is invalid") + session.engine.logger.Warnf("the field [%s] is invalid", col.FieldName) return nil, nil } else { // !nashtsai! 
deference pointer type to instance type @@ -583,9 +587,9 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val case reflect.String: return fieldValue.String(), nil case reflect.Struct: - if fieldType.ConvertibleTo(core.TimeType) { - t := fieldValue.Convert(core.TimeType).Interface().(time.Time) - tf := session.engine.formatColTime(col, t) + if fieldType.ConvertibleTo(schemas.TimeType) { + t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) + tf := dialects.FormatColumnTime(session.engine.dialect, session.engine.DatabaseTZ, col, t) return tf, nil } else if fieldType.ConvertibleTo(nullFloatType) { t := fieldValue.Convert(nullFloatType).Interface().(sql.NullFloat64) @@ -601,7 +605,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val return v.Value() } - fieldTable, err := session.engine.autoMapType(fieldValue) + fieldTable, err := session.engine.tagParser.ParseWithCache(fieldValue) if err != nil { return nil, err } @@ -613,25 +617,25 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val } if col.SQLType.IsText() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) return 0, err } return string(bytes), nil } else if col.SQLType.IsBlob() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) return 0, err } return bytes, nil } return nil, fmt.Errorf("Unsupported type %v", fieldValue.Type()) case reflect.Complex64, reflect.Complex128: - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) return 0, err } return string(bytes), nil @@ -641,9 +645,9 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val } if col.SQLType.IsText() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) return 0, err } return string(bytes), nil @@ -654,9 +658,9 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val (fieldValue.Type().Elem().Kind() == reflect.Uint8) { bytes = fieldValue.Bytes() } else { - bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface()) + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) return 0, err } } diff --git a/session_delete.go b/session_delete.go index 1010d37..2ffd952 100644 --- a/session_delete.go +++ b/session_delete.go @@ -9,37 +9,38 @@ import ( "fmt" "strconv" - "github.com/xormplus/core" + "github.com/xormplus/xorm/caches" + "github.com/xormplus/xorm/schemas" ) -func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string, args ...interface{}) error { +func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error { if table == nil || session.tx != nil { return ErrCacheFailed } for _, filter := range session.engine.dialect.Filters() { 
- sqlStr = filter.Do(sqlStr, session.engine.dialect, table) + sqlStr = filter.Do(sqlStr) } - newsql := session.statement.convertIDSQL(sqlStr) + newsql := session.statement.ConvertIDSQL(sqlStr) if newsql == "" { return ErrCacheFailed } - cacher := session.engine.getCacher(tableName) + cacher := session.engine.cacherMgr.GetCacher(tableName) pkColumns := table.PKColumns() - ids, err := core.GetCacheSql(cacher, tableName, newsql, args) + ids, err := caches.GetCacheSql(cacher, tableName, newsql, args) if err != nil { resultsSlice, err := session.queryBytes(newsql, args...) if err != nil { return err } - ids = make([]core.PK, 0) + ids = make([]schemas.PK, 0) if len(resultsSlice) > 0 { for _, data := range resultsSlice { var id int64 - var pk core.PK = make([]interface{}, 0) + var pk schemas.PK = make([]interface{}, 0) for _, col := range pkColumns { if v, ok := data[col.Name]; !ok { return errors.New("no id") @@ -61,14 +62,14 @@ func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string, } for _, id := range ids { - session.engine.logger.Debug("[cacheDelete] delete cache obj:", tableName, id) + session.engine.logger.Debugf("[cache] delete cache obj: %v, %v", tableName, id) sid, err := id.ToString() if err != nil { return err } cacher.DelBean(tableName, sid) } - session.engine.logger.Debug("[cacheDelete] clear cache table:", tableName) + session.engine.logger.Debugf("[cache] clear cache table: %v", tableName) cacher.ClearIds(tableName) return nil } @@ -79,25 +80,21 @@ func (session *Session) Delete(bean interface{}) (int64, error) { defer session.Close() } - if session.statement.lastError != nil { - return 0, session.statement.lastError + if session.statement.LastError != nil { + return 0, session.statement.LastError } - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return 0, err } - // handle before delete processors - for _, closure := range session.beforeClosures { - closure(bean) - } - cleanupProcessorsClosures(&session.beforeClosures) + executeBeforeClosures(session, bean) if processor, ok := interface{}(bean).(BeforeDeleteProcessor); ok { processor.BeforeDelete() } - condSQL, condArgs, err := session.statement.genConds(bean) + condSQL, condArgs, err := session.statement.GenConds(bean) if err != nil { return 0, err } @@ -126,23 +123,23 @@ func (session *Session) Delete(bean interface{}) (int64, error) { } if len(orderSQL) > 0 { - switch session.engine.dialect.DBType() { - case core.POSTGRES: + switch session.engine.dialect.URI().DBType { + case schemas.POSTGRES: inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) if len(condSQL) > 0 { deleteSQL += " AND " + inSQL } else { deleteSQL += " WHERE " + inSQL } - case core.SQLITE: + case schemas.SQLITE: inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) if len(condSQL) > 0 { deleteSQL += " AND " + inSQL } else { deleteSQL += " WHERE " + inSQL } - // TODO: how to handle delete limit on mssql? - case core.MSSQL: + // TODO: how to handle delete limit on mssql? 
+ case schemas.MSSQL: return 0, ErrNotImplemented default: deleteSQL += orderSQL @@ -151,12 +148,12 @@ func (session *Session) Delete(bean interface{}) (int64, error) { var realSQL string argsForCache := make([]interface{}, 0, len(condArgs)*2) - if session.statement.unscoped || table.DeletedColumn() == nil { // tag "deleted" is disabled + if session.statement.GetUnscoped() || table.DeletedColumn() == nil { // tag "deleted" is disabled realSQL = deleteSQL copy(argsForCache, condArgs) argsForCache = append(condArgs, argsForCache...) } else { - // !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for cache. + // !oinume! sqlStrForCache and argsForCache are needed to behave as executing "DELETE FROM ..." for caches. copy(argsForCache, condArgs) argsForCache = append(condArgs, argsForCache...) @@ -167,23 +164,23 @@ func (session *Session) Delete(bean interface{}) (int64, error) { condSQL) if len(orderSQL) > 0 { - switch session.engine.dialect.DBType() { - case core.POSTGRES: + switch session.engine.dialect.URI().DBType { + case schemas.POSTGRES: inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) if len(condSQL) > 0 { realSQL += " AND " + inSQL } else { realSQL += " WHERE " + inSQL } - case core.SQLITE: + case schemas.SQLITE: inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) if len(condSQL) > 0 { realSQL += " AND " + inSQL } else { realSQL += " WHERE " + inSQL } - // TODO: how to handle delete limit on mssql? - case core.MSSQL: + // TODO: how to handle delete limit on mssql? + case schemas.MSSQL: return 0, ErrNotImplemented default: realSQL += orderSQL @@ -205,7 +202,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { }) } - if cacher := session.engine.getCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache { + if cacher := session.engine.GetCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache { session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...)
} diff --git a/session_exist.go b/session_exist.go index 072f0c5..e52c618 100644 --- a/session_exist.go +++ b/session_exist.go @@ -4,90 +4,19 @@ package xorm -import ( - "errors" - "fmt" - "reflect" - - "github.com/xormplus/builder" - "github.com/xormplus/core" -) - // Exist returns true if the record exist otherwise return false func (session *Session) Exist(bean ...interface{}) (bool, error) { if session.isAutoClose { defer session.Close() } - if session.statement.lastError != nil { - return false, session.statement.lastError + if session.statement.LastError != nil { + return false, session.statement.LastError } - var sqlStr string - var args []interface{} - var joinStr string - var err error - - if session.statement.RawSQL == "" { - if len(bean) == 0 { - tableName := session.statement.TableName() - if len(tableName) <= 0 { - return false, ErrTableNotFound - } - - tableName = session.statement.Engine.Quote(tableName) - if len(session.statement.JoinStr) > 0 { - joinStr = session.statement.JoinStr - } - - if session.statement.cond.IsValid() { - condSQL, condArgs, err := builder.ToSQL(session.statement.cond) - if err != nil { - return false, err - } - - if session.engine.dialect.DBType() == core.MSSQL { - sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL) - } else if session.engine.dialect.DBType() == core.ORACLE { - sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL) - } else { - sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL) - } - args = condArgs - } else { - if session.engine.dialect.DBType() == core.MSSQL { - sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr) - } else if session.engine.dialect.DBType() == core.ORACLE { - sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr) - } else { - sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr) - } - args = []interface{}{} - } - } else { - beanValue := reflect.ValueOf(bean[0]) - if beanValue.Kind() != reflect.Ptr { - return false, errors.New("needs a pointer") - } - - if beanValue.Elem().Kind() == reflect.Struct { - if err := session.statement.setRefBean(bean[0]); err != nil { - return false, err - } - } - - if len(session.statement.TableName()) <= 0 { - return false, ErrTableNotFound - } - session.statement.Limit(1) - sqlStr, args, err = session.statement.genGetSQL(bean[0]) - if err != nil { - return false, err - } - } - } else { - sqlStr = session.statement.RawSQL - args = session.statement.RawParams + sqlStr, args, err := session.statement.GenExistSQL(bean...) + if err != nil { + return false, err } rows, err := session.queryRows(sqlStr, args...) 
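The session_exist.go hunk above collapses the inline, per-dialect probe SQL into a single statement.GenExistSQL call. For review purposes, here is a minimal sketch of the dialect switch the deleted block implemented; the standalone existSQL function, its parameters, and the tidied Oracle clause are illustrative only, not the API of the new internal/statements package:

package main

import "fmt"

// existSQL mirrors the removed Exist body: each dialect needs its own
// way to fetch at most one row.
func existSQL(dbType, tableName, joinStr, condSQL string) string {
	switch dbType {
	case "mssql":
		// MSSQL has no LIMIT clause; TOP 1 bounds the result set instead.
		return fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL)
	case "oracle":
		// Oracle (pre-12c) bounds rows with ROWNUM in the WHERE clause.
		return fmt.Sprintf("SELECT * FROM %s %s WHERE (%s) AND ROWNUM=1", tableName, joinStr, condSQL)
	default:
		// MySQL, PostgreSQL, SQLite and others all accept LIMIT 1.
		return fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL)
	}
}

func main() {
	// Prints the MSSQL form: SELECT TOP 1 * FROM user  WHERE name = ?
	fmt.Println(existSQL("mssql", "user", "", "name = ?"))
}
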
diff --git a/session_find.go b/session_find.go index 72aa929..9426d90 100644 --- a/session_find.go +++ b/session_find.go @@ -8,10 +8,13 @@ import ( "errors" "fmt" "reflect" - "strings" "github.com/xormplus/builder" - "github.com/xormplus/core" + "github.com/xormplus/xorm/caches" + "github.com/xormplus/xorm/core" + "github.com/xormplus/xorm/internal/statements" + "github.com/xormplus/xorm/internal/utils" + "github.com/xormplus/xorm/schemas" ) const ( @@ -70,8 +73,8 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte } session.autoResetStatement = true - if session.statement.selectStr != "" { - session.statement.selectStr = "" + if session.statement.SelectStr != "" { + session.statement.SelectStr = "" } if session.statement.OrderStr != "" { @@ -85,13 +88,14 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { defer session.resetStatement() - - if session.statement.lastError != nil { - return session.statement.lastError + if session.statement.LastError != nil { + return session.statement.LastError } sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) - if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { + var isSlice = sliceValue.Kind() == reflect.Slice + var isMap = sliceValue.Kind() == reflect.Map + if !isSlice && !isMap { return errors.New("needs a pointer to a slice or a map") } @@ -102,7 +106,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Elem().Kind() == reflect.Struct { pv := reflect.New(sliceElementType.Elem()) - if err := session.statement.setRefValue(pv); err != nil { + if err := session.statement.SetRefValue(pv); err != nil { return err } } else { @@ -110,7 +114,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } } else if sliceElementType.Kind() == reflect.Struct { pv := reflect.New(sliceElementType) - if err := session.statement.setRefValue(pv); err != nil { + if err := session.statement.SetRefValue(pv); err != nil { return err } } else { @@ -118,107 +122,54 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } } - var table = session.statement.RefTable - - var addedTableName = (len(session.statement.JoinStr) > 0) - var autoCond builder.Cond + var ( + table = session.statement.RefTable + addedTableName = (len(session.statement.JoinStr) > 0) + autoCond builder.Cond + ) if tp == tpStruct { - if !session.statement.noAutoCondition && len(condiBean) > 0 { + if !session.statement.NoAutoCondition && len(condiBean) > 0 { var err error - autoCond, err = session.statement.buildConds(table, condiBean[0], true, true, false, true, addedTableName) + autoCond, err = session.statement.BuildConds(table, condiBean[0], true, true, false, true, addedTableName) if err != nil { return err } } else { - // !oinume! Add " IS NULL" to WHERE whatever condiBean is given. - // See https://github.com/go-xorm/xorm/issues/179 - if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled - var colName = session.engine.Quote(col.Name) - if addedTableName { - var nm = session.statement.TableName() - if len(session.statement.TableAlias) > 0 { - nm = session.statement.TableAlias - } - colName = session.engine.Quote(nm) + "." 
+ colName - } - - autoCond = session.engine.CondDeleted(col) + if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled + autoCond = session.statement.CondDeleted(col) } } } - var sqlStr string - var args []interface{} - var err error - if session.statement.RawSQL == "" { - if len(session.statement.TableName()) <= 0 { - return ErrTableNotFound - } - - var columnStr = session.statement.ColumnStr - if len(session.statement.selectStr) > 0 { - columnStr = session.statement.selectStr - } else { - if session.statement.JoinStr == "" { - if columnStr == "" { - if session.statement.GroupByStr != "" { - columnStr = session.engine.quoteColumns(session.statement.GroupByStr) - } else { - columnStr = session.statement.genColumnStr() - } - } - } else { - if columnStr == "" { - if session.statement.GroupByStr != "" { - columnStr = session.engine.quoteColumns(session.statement.GroupByStr) - } else { - columnStr = "*" - } - } - } - if columnStr == "" { - columnStr = "*" - } - } - - session.statement.cond = session.statement.cond.And(autoCond) - condSQL, condArgs, err := builder.ToSQL(session.statement.cond) - if err != nil { - return err + // if it's a map with Cols but primary key not in column list, we still need the primary key + if isMap && !session.statement.ColumnMap.IsEmpty() { + for _, k := range session.statement.RefTable.PrimaryKeys { + session.statement.ColumnMap.Add(k) } + } - args = append(session.statement.joinArgs, condArgs...) - sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true, true) - if err != nil { - return err - } - // for mssql and use limit - qs := strings.Count(sqlStr, "?") - if len(args)*2 == qs { - args = append(args, args...) - } - } else { - sqlStr = session.statement.RawSQL - args = session.statement.RawParams + sqlStr, args, err := session.statement.GenFindSQL(autoCond) + if err != nil { + return err } - if session.canCache() { - if cacher := session.engine.getCacher(session.statement.TableName()); cacher != nil && + if session.statement.ColumnMap.IsEmpty() && session.canCache() { + if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil && !session.statement.IsDistinct && - !session.statement.unscoped { + !session.statement.GetUnscoped() { err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) if err != ErrCacheFailed { return err } err = nil // !nashtsai! reset err to nil for ErrCacheFailed - session.engine.logger.Warn("Cache Find Failed") + session.engine.logger.Warnf("Cache Find Failed") } } if sliceValue.Kind() != reflect.Map { if session.isSqlFunc { - var dialect = session.statement.Engine.Dialect() - rownumber := "xorm" + NewShortUUID().String() + var dialect = session.engine.Dialect() + rownumber := "xorm" + utils.NewShortUUID().String() sql := session.genSelectSql(dialect, rownumber) params := session.statement.RawParams @@ -245,7 +196,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) return session.noCacheFind(table, sliceValue, sqlStr, args...) } -func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error { +func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error { rows, err := session.queryRows(sqlStr, args...) 
if err != nil { return err @@ -284,10 +235,10 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va return reflect.New(elemType) } - var containerValueSetFunc func(*reflect.Value, core.PK) error + var containerValueSetFunc func(*reflect.Value, schemas.PK) error if containerValue.Kind() == reflect.Slice { - containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error { + containerValueSetFunc = func(newValue *reflect.Value, pk schemas.PK) error { if isPointer { containerValue.Set(reflect.Append(containerValue, newValue.Elem().Addr())) } else { @@ -304,7 +255,7 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va return errors.New("don't support multiple primary key's map has non-slice key type") } - containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error { + containerValueSetFunc = func(newValue *reflect.Value, pk schemas.PK) error { keyValue := reflect.New(keyType) err := convertPKToValue(table, keyValue.Interface(), pk) if err != nil { @@ -321,8 +272,8 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va if elemType.Kind() == reflect.Struct { var newValue = newElemFunc(fields) - dataStruct := rValue(newValue.Interface()) - tb, err := session.engine.autoMapType(dataStruct) + dataStruct := utils.ReflectValue(newValue.Interface()) + tb, err := session.engine.tagParser.ParseWithCache(dataStruct) if err != nil { return err } @@ -358,7 +309,7 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va return nil } -func convertPKToValue(table *core.Table, dst interface{}, pk core.PK) error { +func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error { cols := table.PKColumns() if len(cols) == 1 { return convertAssign(dst, pk[0]) @@ -370,28 +321,28 @@ func convertPKToValue(table *core.Table, dst interface{}, pk core.PK) error { func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) { if !session.canCache() || - indexNoCase(sqlStr, "having") != -1 || - indexNoCase(sqlStr, "group by") != -1 { + utils.IndexNoCase(sqlStr, "having") != -1 || + utils.IndexNoCase(sqlStr, "group by") != -1 { return ErrCacheFailed } tableName := session.statement.TableName() - cacher := session.engine.getCacher(tableName) + cacher := session.engine.cacherMgr.GetCacher(tableName) if cacher == nil { return nil } for _, filter := range session.engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable) + sqlStr = filter.Do(sqlStr) } - newsql := session.statement.convertIDSQL(sqlStr) + newsql := session.statement.ConvertIDSQL(sqlStr) if newsql == "" { return ErrCacheFailed } table := session.statement.RefTable - ids, err := core.GetCacheSql(cacher, tableName, newsql, args) + ids, err := caches.GetCacheSql(cacher, tableName, newsql, args) if err != nil { rows, err := session.queryRows(newsql, args...) 
if err != nil { @@ -400,11 +351,11 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in defer rows.Close() var i int - ids = make([]core.PK, 0) + ids = make([]schemas.PK, 0) for rows.Next() { i++ if i > 500 { - session.engine.logger.Debug("[cacheFind] ids length > 500, no cache") + session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache") return ErrCacheFailed } var res = make([]string, len(table.PrimaryKeys)) @@ -412,7 +363,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in if err != nil { return err } - var pk core.PK = make([]interface{}, len(table.PrimaryKeys)) + var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys)) for i, col := range table.PKColumns() { pk[i], err = session.engine.idTypeAssertion(col, res[i]) if err != nil { @@ -423,19 +374,19 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in ids = append(ids, pk) } - session.engine.logger.Debug("[cacheFind] cache sql:", ids, tableName, sqlStr, newsql, args) - err = core.PutCacheSql(cacher, ids, tableName, newsql, args) + session.engine.logger.Debugf("[cache] cache sql: %v, %v, %v, %v, %v", ids, tableName, sqlStr, newsql, args) + err = caches.PutCacheSql(cacher, ids, tableName, newsql, args) if err != nil { return err } } else { - session.engine.logger.Debug("[cacheFind] cache hit sql:", tableName, sqlStr, newsql, args) + session.engine.logger.Debugf("[cache] cache hit sql: %v, %v, %v, %v", tableName, sqlStr, newsql, args) } sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) ididxes := make(map[string]int) - var ides []core.PK + var ides []schemas.PK var temps = make([]interface{}, len(ids)) for idx, id := range ids { @@ -460,16 +411,20 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in ides = append(ides, id) ididxes[sid] = idx } else { - session.engine.logger.Debug("[cacheFind] cache hit bean:", tableName, id, bean) + session.engine.logger.Debugf("[cache] cache hit bean: %v, %v, %v", tableName, id, bean) + + pk, err := session.engine.IDOf(bean) + if err != nil { + return err + } - pk := session.engine.IdOf(bean) xid, err := pk.ToString() if err != nil { return err } if sid != xid { - session.engine.logger.Error("[cacheFind] error cache", xid, sid, bean) + session.engine.logger.Errorf("[cache] error cache: %v, %v, %v", xid, sid, bean) return ErrCacheFailed } temps[idx] = bean @@ -480,6 +435,12 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in slices := reflect.New(reflect.SliceOf(t)) beans := slices.Interface() + statement := session.statement + session.statement = statements.NewStatement( + session.engine.dialect, + session.engine.tagParser, + session.engine.DatabaseTZ, + ) if len(table.PrimaryKeys) == 1 { ff := make([]interface{}, 0, len(ides)) for _, ie := range ides { @@ -502,6 +463,8 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in return err } + session.statement = statement + vs := reflect.Indirect(reflect.ValueOf(beans)) for i := 0; i < vs.Len(); i++ { rv := vs.Index(i) @@ -519,7 +482,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in bean := rv.Interface() temps[ididxes[sid]] = bean - session.engine.logger.Debug("[cacheFind] cache bean:", tableName, id, bean, temps) + session.engine.logger.Debugf("[cache] cache bean: %v, %v, %v, %v", tableName, id, bean, temps) cacher.PutBean(tableName, sid, bean) } } @@ -527,7 +490,7 @@ func (session *Session) cacheFind(t 
reflect.Type, sqlStr string, rowsSlicePtr in for j := 0; j < len(temps); j++ { bean := temps[j] if bean == nil { - session.engine.logger.Warn("[cacheFind] cache no hit:", tableName, ids[j], temps) + session.engine.logger.Warnf("[cache] cache no hit: %v, %v, %v", tableName, ids[j], temps) // return errors.New("cache error") // !nashtsai! no need to return error, but continue instead continue } @@ -548,7 +511,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in } } else { if keyType.Kind() != reflect.Slice { - return errors.New("table have multiple primary keys, key is not core.PK or slice") + return errors.New("table has multiple primary keys, key is not schemas.PK or slice") } ikey = key } diff --git a/session_get.go b/session_get.go index 649ea76..932ffc0 100644 --- a/session_get.go +++ b/session_get.go @@ -11,7 +11,10 @@ import ( "reflect" "strconv" - "github.com/xormplus/core" + "github.com/xormplus/xorm/caches" + "github.com/xormplus/xorm/core" + "github.com/xormplus/xorm/internal/utils" + "github.com/xormplus/xorm/schemas" ) // Get retrieve one record from database, bean's non-empty fields @@ -26,8 +29,8 @@ func (session *Session) Get(bean interface{}) (bool, error) { func (session *Session) get(bean interface{}) (bool, error) { defer session.resetStatement() - if session.statement.lastError != nil { - return false, session.statement.lastError + if session.statement.LastError != nil { + return false, session.statement.LastError } beanValue := reflect.ValueOf(bean) @@ -38,7 +41,7 @@ func (session *Session) get(bean interface{}) (bool, error) { } if beanValue.Elem().Kind() == reflect.Struct { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return false, err } } @@ -52,12 +55,12 @@ func (session *Session) get(bean interface{}) (bool, error) { return false, ErrTableNotFound } session.statement.Limit(1) - sqlStr, args, err = session.statement.genGetSQL(bean) + sqlStr, args, err = session.statement.GenGetSQL(bean) if err != nil { return false, err } } else { - sqlStr = session.statement.RawSQL + sqlStr = session.statement.GenRawSQL() params := session.statement.RawParams i := len(params) if i == 1 { @@ -74,9 +77,9 @@ func (session *Session) get(bean interface{}) (bool, error) { table := session.statement.RefTable - if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { - if cacher := session.engine.getCacher(session.statement.TableName()); cacher != nil && - !session.statement.unscoped { + if session.statement.ColumnMap.IsEmpty() && session.canCache() && beanValue.Elem().Kind() == reflect.Struct { + if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil && + !session.statement.GetUnscoped() { has, err := session.cacheGet(bean, sqlStr, args...)
if err != ErrCacheFailed { return has, err @@ -84,11 +87,11 @@ func (session *Session) get(bean interface{}) (bool, error) { } } - context := session.statement.context + context := session.statement.Context if context != nil { res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args)) if res != nil { - session.engine.logger.Debug("hit context cache", sqlStr) + session.engine.logger.Debugf("hit context cache: %s", sqlStr) structValue := reflect.Indirect(reflect.ValueOf(bean)) structValue.Set(reflect.Indirect(reflect.ValueOf(res))) @@ -110,7 +113,7 @@ func (session *Session) get(bean interface{}) (bool, error) { return true, nil } -func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { +func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { rows, err := session.queryRows(sqlStr, args...) if err != nil { return false, err @@ -268,10 +271,10 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea if err != nil { return false, err } - // close it before covert data + // close it before converting the data rows.Close() - dataStruct := rValue(bean) + dataStruct := utils.ReflectValue(bean) _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) if err != nil { return true, err @@ -299,19 +302,19 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf } for _, filter := range session.engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable) + sqlStr = filter.Do(sqlStr) } - newsql := session.statement.convertIDSQL(sqlStr) + newsql := session.statement.ConvertIDSQL(sqlStr) if newsql == "" { return false, ErrCacheFailed } tableName := session.statement.TableName() - cacher := session.engine.getCacher(tableName) + cacher := session.engine.cacherMgr.GetCacher(tableName) - session.engine.logger.Debug("[cacheGet] find sql:", newsql, args) + session.engine.logger.Debugf("[cache] Get SQL: %s, %v", newsql, args) table := session.statement.RefTable - ids, err := core.GetCacheSql(cacher, tableName, newsql, args) + ids, err := caches.GetCacheSql(cacher, tableName, newsql, args) if err != nil { var res = make([]string, len(table.PrimaryKeys)) rows, err := session.NoCache().queryRows(newsql, args...)
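The cacheGet hunks here (and the cacheFind ones in session_find.go above) now route all SQL-to-primary-key caching through the new caches package. A minimal sketch of that read-through flow, using the caches.GetCacheSql and caches.PutCacheSql signatures visible in this patch; the cachedIDs helper is hypothetical, and the caches.Cacher interface name is assumed from the old core.Cacher:

package xorm

import (
	"github.com/xormplus/xorm/caches"
	"github.com/xormplus/xorm/schemas"
)

// cachedIDs illustrates the lookup order: try the SQL->ids cache first,
// fall back to the supplied query on a miss, then store the mapping so
// the next identical statement can skip the database.
func cachedIDs(cacher caches.Cacher, tableName, idSQL string, args []interface{},
	query func() ([]schemas.PK, error)) ([]schemas.PK, error) {
	if ids, err := caches.GetCacheSql(cacher, tableName, idSQL, args); err == nil {
		return ids, nil // cache hit: primary keys recovered without a query
	}
	ids, err := query() // cache miss: run the converted ID-only SQL
	if err != nil {
		return nil, err
	}
	return ids, caches.PutCacheSql(cacher, ids, tableName, idSQL, args)
}
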
@@ -329,7 +332,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf return false, ErrCacheFailed } - var pk core.PK = make([]interface{}, len(table.PrimaryKeys)) + var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys)) for i, col := range table.PKColumns() { if col.SQLType.IsText() { pk[i] = res[i] @@ -344,20 +347,20 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf } } - ids = []core.PK{pk} - session.engine.logger.Debug("[cacheGet] cache ids:", newsql, ids) - err = core.PutCacheSql(cacher, ids, tableName, newsql, args) + ids = []schemas.PK{pk} + session.engine.logger.Debugf("[cache] cache ids: %s, %v", newsql, ids) + err = caches.PutCacheSql(cacher, ids, tableName, newsql, args) if err != nil { return false, err } } else { - session.engine.logger.Debug("[cacheGet] cache hit sql:", newsql, ids) + session.engine.logger.Debugf("[cache] cache hit: %s, %v", newsql, ids) } if len(ids) > 0 { structValue := reflect.Indirect(reflect.ValueOf(bean)) id := ids[0] - session.engine.logger.Debug("[cacheGet] get bean:", tableName, id) + session.engine.logger.Debugf("[cache] get bean: %s, %v", tableName, id) sid, err := id.ToString() if err != nil { return false, err @@ -370,10 +373,10 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf return has, err } - session.engine.logger.Debug("[cacheGet] cache bean:", tableName, id, cacheBean) + session.engine.logger.Debugf("[cache] cache bean: %s, %v, %v", tableName, id, cacheBean) cacher.PutBean(tableName, sid, cacheBean) } else { - session.engine.logger.Debug("[cacheGet] cache hit bean:", tableName, id, cacheBean) + session.engine.logger.Debugf("[cache] cache hit: %s, %v, %v", tableName, id, cacheBean) has = true } structValue.Set(reflect.Indirect(reflect.ValueOf(cacheBean))) diff --git a/session_insert.go b/session_insert.go index dcc7060..dc37ae6 100644 --- a/session_insert.go +++ b/session_insert.go @@ -12,8 +12,9 @@ import ( "strconv" "strings" - "github.com/xormplus/builder" - "github.com/xormplus/core" + // "github.com/xormplus/builder" + "github.com/xormplus/xorm/internal/utils" + "github.com/xormplus/xorm/schemas" ) // ErrNoElementsOnSlice represents an error there is no element when insert @@ -74,21 +75,11 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { return 0, ErrNoElementsOnSlice } - if session.engine.SupportInsertMany() { - cnt, err := session.innerInsertMulti(bean) - if err != nil { - return affected, err - } - affected += cnt - } else { - for i := 0; i < size; i++ { - cnt, err := session.innerInsert(sliceValue.Index(i).Interface()) - if err != nil { - return affected, err - } - affected += cnt - } + cnt, err := session.innerInsertMulti(bean) + if err != nil { + return affected, err } + affected += cnt } else { cnt, err := session.innerInsert(bean) if err != nil { @@ -112,7 +103,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error return 0, errors.New("could not insert a empty slice") } - if err := session.statement.setRefBean(sliceValue.Index(0).Interface()); err != nil { + if err := session.statement.SetRefBean(sliceValue.Index(0).Interface()); err != nil { return 0, err } @@ -121,17 +112,24 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error return 0, ErrTableNotFound } - table := session.statement.RefTable - size := sliceValue.Len() - - var colNames []string - var colMultiPlaces []string - var args []interface{} - var cols []*core.Column + var ( + 
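
The cacheGet hunks above implement a two-level lookup: the SELECT is first rewritten into a primary-key-only query (ConvertIDSQL), the resulting ids are cached keyed by SQL+args, and beans are then fetched by id from a separate bean cache. A minimal sketch of that flow, assuming caches.Cacher is the interface behind engine.GetCacher and using the caches/schemas calls shown in the diff; queryPK is a hypothetical stand-in for the NoCache().queryRows fallback:

package example

import (
	"github.com/xormplus/xorm/caches"
	"github.com/xormplus/xorm/schemas"
)

// cachedGet sketches the two-level lookup used by cacheGet.
func cachedGet(cacher caches.Cacher, table, idSQL string, args []interface{},
	queryPK func() (schemas.PK, error)) (interface{}, error) {
	// Level 1: (SQL, args) -> list of primary keys.
	ids, err := caches.GetCacheSql(cacher, table, idSQL, args)
	if err != nil { // miss: run the id-only query, then cache the ids
		pk, err := queryPK()
		if err != nil {
			return nil, err
		}
		ids = []schemas.PK{pk}
		if err := caches.PutCacheSql(cacher, ids, table, idSQL, args); err != nil {
			return nil, err
		}
	}
	if len(ids) == 0 {
		return nil, nil
	}
	// Level 2: primary key -> cached bean; nil means the bean
	// must be reloaded from the database and PutBean'd again.
	sid, err := ids[0].ToString()
	if err != nil {
		return nil, err
	}
	return cacher.GetBean(table, sid), nil
}
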
table = session.statement.RefTable + size = sliceValue.Len() + colNames []string + colMultiPlaces []string + args []interface{} + cols []*schemas.Column + ) for i := 0; i < size; i++ { v := sliceValue.Index(i) - vv := reflect.Indirect(v) + var vv reflect.Value + switch v.Kind() { + case reflect.Interface: + vv = reflect.Indirect(v.Elem()) + default: + vv = reflect.Indirect(v) + } elemValue := v.Interface() var colPlaces []string @@ -146,123 +144,77 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } // -- - if i == 0 { - for _, col := range table.Columns() { - ptrFieldValue, err := col.ValueOfV(&vv) + for _, col := range table.Columns() { + ptrFieldValue, err := col.ValueOfV(&vv) + if err != nil { + return 0, err + } + fieldValue := *ptrFieldValue + if col.IsAutoIncrement && utils.IsZero(fieldValue.Interface()) { + continue + } + if col.MapType == schemas.ONLYFROMDB { + continue + } + if col.IsDeleted { + continue + } + if session.statement.OmitColumnMap.Contain(col.Name) { + continue + } + if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) { + continue + } + if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { + val, t := session.engine.nowTime(col) + args = append(args, val) + + var colName = col.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnTime(bean, col, t) + }) + } else if col.IsVersion && session.statement.CheckVersion { + args = append(args, 1) + var colName = col.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnInt(bean, col, 1) + }) + } else { + arg, err := session.statement.Value2Interface(col, fieldValue) if err != nil { return 0, err } - fieldValue := *ptrFieldValue - if col.IsAutoIncrement && isZero(fieldValue.Interface()) { - continue - } - if col.MapType == core.ONLYFROMDB { - continue - } - if col.IsDeleted { - continue - } - if session.statement.omitColumnMap.contain(col.Name) { - continue - } - if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { - continue - } - if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { - val, t := session.engine.nowTime(col) - args = append(args, val) - - var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnTime(bean, col, t) - }) - } else if col.IsVersion && session.statement.checkVersion { - args = append(args, 1) - var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnInt(bean, col, 1) - }) - } else { - arg, err := session.value2Interface(col, fieldValue) - if err != nil { - return 0, err - } - args = append(args, arg) - } + args = append(args, arg) + } + if i == 0 { colNames = append(colNames, col.Name) cols = append(cols, col) - colPlaces = append(colPlaces, "?") - } - } else { - for _, col := range cols { - ptrFieldValue, err := col.ValueOfV(&vv) - if err != nil { - return 0, err - } - fieldValue := *ptrFieldValue - - if col.IsAutoIncrement && isZero(fieldValue.Interface()) { - continue - } - if col.MapType == core.ONLYFROMDB { - continue - } - if col.IsDeleted { - continue - } - if session.statement.omitColumnMap.contain(col.Name) { - continue - } - if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) 
{ - continue - } - if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { - val, t := session.engine.nowTime(col) - args = append(args, val) - - var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnTime(bean, col, t) - }) - } else if col.IsVersion && session.statement.checkVersion { - args = append(args, 1) - var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnInt(bean, col, 1) - }) - } else { - arg, err := session.value2Interface(col, fieldValue) - if err != nil { - return 0, err - } - args = append(args, arg) - } - - colPlaces = append(colPlaces, "?") } + colPlaces = append(colPlaces, "?") } + colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", ")) } cleanupProcessorsClosures(&session.beforeClosures) + quoter := session.engine.dialect.Quoter() var sql string - if session.engine.dialect.DBType() == core.ORACLE { + colStr := quoter.Join(colNames, ",") + if session.engine.dialect.URI().DBType == schemas.ORACLE { temp := fmt.Sprintf(") INTO %s (%v) VALUES (", - session.engine.Quote(tableName), - quoteColumns(colNames, session.engine.Quote, ",")) + quoter.Quote(tableName), + colStr) sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL", - session.engine.Quote(tableName), - quoteColumns(colNames, session.engine.Quote, ","), + quoter.Quote(tableName), + colStr, strings.Join(colMultiPlaces, temp)) } else { sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)", - session.engine.Quote(tableName), - quoteColumns(colNames, session.engine.Quote, ","), + quoter.Quote(tableName), + colStr, strings.Join(colMultiPlaces, "),(")) } res, err := session.exec(sql, args...) 
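
The rewritten loop above collects the argument list once per row and records column metadata only on the first row; the SQL text then differs by dialect, because Oracle has no multi-row VALUES list. A sketch of the two statement shapes generated above, with quoting elided for brevity (rowPlaces holds one "?, ?, ..." group per row):

package example

import (
	"fmt"
	"strings"
)

// multiInsertSQL mirrors the dialect branch above.
func multiInsertSQL(oracle bool, table string, colNames, rowPlaces []string) string {
	cols := strings.Join(colNames, ",")
	if oracle {
		// Oracle: each row becomes its own INTO clause under a single
		// INSERT ALL ... SELECT 1 FROM DUAL statement.
		sep := fmt.Sprintf(") INTO %s (%s) VALUES (", table, cols)
		return fmt.Sprintf("INSERT ALL INTO %s (%s) VALUES (%s) SELECT 1 FROM DUAL",
			table, cols, strings.Join(rowPlaces, sep))
	}
	// Other dialects: comma-separated placeholder groups in one statement.
	return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
		table, cols, strings.Join(rowPlaces, "),("))
}

For two rows of two columns the non-Oracle branch yields INSERT INTO t (a,b) VALUES (?, ?),(?, ?), which is why args is built as one flat slice across all rows.
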
@@ -283,7 +235,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error for _, closure := range session.afterClosures { closure(elemValue) } - if processor, ok := interface{}(elemValue).(AfterInsertProcessor); ok { + if processor, ok := elemValue.(AfterInsertProcessor); ok { processor.AfterInsert() } } else { @@ -296,7 +248,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error session.afterInsertBeans[elemValue] = &afterClosures } } else { - if _, ok := interface{}(elemValue).(AfterInsertProcessor); ok { + if _, ok := elemValue.(AfterInsertProcessor); ok { session.afterInsertBeans[elemValue] = nil } } @@ -315,27 +267,24 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) if sliceValue.Kind() != reflect.Slice { - return 0, ErrParamsType - + return 0, ErrPtrSliceType } if sliceValue.Len() <= 0 { - return 0, nil + return 0, ErrNoElementsOnSlice } return session.innerInsertMulti(rowsSlicePtr) } func (session *Session) innerInsert(bean interface{}) (int64, error) { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return 0, err } if len(session.statement.TableName()) <= 0 { return 0, ErrTableNotFound } - table := session.statement.RefTable - // handle BeforeInsertProcessor for _, closure := range session.beforeClosures { closure(bean) @@ -346,101 +295,19 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { processor.BeforeInsert() } + var tableName = session.statement.TableName() + table := session.statement.RefTable + colNames, args, err := session.genInsertColumns(bean) if err != nil { return 0, err } - exprs := session.statement.exprColumns - colPlaces := strings.Repeat("?, ", len(colNames)) - if exprs.Len() <= 0 && len(colPlaces) > 0 { - colPlaces = colPlaces[0 : len(colPlaces)-2] - } - - var tableName = session.statement.TableName() - var output string - if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 { - output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement) - } - - var buf = builder.NewWriter() - if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s", session.engine.Quote(tableName))); err != nil { + sqlStr, args, err := session.statement.GenInsertSQL(colNames, args) + if err != nil { return 0, err } - if len(colPlaces) <= 0 { - if session.engine.dialect.DBType() == core.MYSQL { - if _, err := buf.WriteString(" VALUES ()"); err != nil { - - return 0, err - } - - } else { - if _, err := buf.WriteString(fmt.Sprintf("%s DEFAULT VALUES", output)); err != nil { - return 0, err - } - - } - } else { - if _, err := buf.WriteString(" ("); err != nil { - return 0, err - } - if err := writeStrings(buf, append(colNames, exprs.colNames...), "`", "`"); err != nil { - return 0, err - } - - if session.statement.cond.IsValid() { - if _, err := buf.WriteString(fmt.Sprintf(")%s SELECT ", output)); err != nil { - return 0, err - } - - if err := session.statement.writeArgs(buf, args); err != nil { - return 0, err - } - - if len(exprs.args) > 0 { - if _, err := buf.WriteString(","); err != nil { - return 0, err - } - } - if err := exprs.writeArgs(buf); err != nil { - return 0, err - } - - if _, err := buf.WriteString(fmt.Sprintf(" FROM %v WHERE ", session.engine.Quote(tableName))); err != nil { - return 0, err - } - - if err := session.statement.cond.WriteTo(buf); err != nil { - return 0, err - } - } else { - 
buf.Append(args...) - if _, err := buf.WriteString(fmt.Sprintf(")%s VALUES (%v", - output, - colPlaces)); err != nil { - return 0, err - } - - if err := exprs.writeArgs(buf); err != nil { - return 0, err - } - - if _, err := buf.WriteString(")"); err != nil { - return 0, err - } - } - } - - if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == core.POSTGRES { - if _, err := buf.WriteString(" RETURNING " + session.engine.Quote(table.AutoIncrement)); err != nil { - return 0, err - } - } - - sqlStr := buf.String() - args = buf.Args() - handleAfterInsertProcessorFunc := func(bean interface{}) { if session.isAutoCommit { for _, closure := range session.afterClosures { @@ -471,7 +338,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { // for postgres, many of them didn't implement lastInsertId, so we should // implemented it ourself. - if session.engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 { + if session.engine.dialect.URI().DBType == schemas.ORACLE && len(table.AutoIncrement) > 0 { res, err := session.queryBytes("select seq_atable.currval from dual", args...) if err != nil { return 0, err @@ -481,10 +348,10 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { session.cacheInsert(tableName) - if table.Version != "" && session.statement.checkVersion { + if table.Version != "" && session.statement.CheckVersion { verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { - session.engine.logger.Errorf("[SQL][%p] %v", session, err) + session.engine.logger.Errorf("%v", err) } else if verValue.IsValid() && verValue.CanSet() { session.incrVersionFieldValue(verValue) } @@ -512,7 +379,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue.Set(int64ToIntValue(id, aiValue.Type())) return 1, nil - } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.DBType() == core.POSTGRES || session.engine.dialect.DBType() == core.MSSQL) { + } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || + session.engine.dialect.URI().DBType == schemas.MSSQL) { res, err := session.queryBytes(sqlStr, args...) if err != nil { @@ -522,7 +390,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { session.cacheInsert(tableName) - if table.Version != "" && session.statement.checkVersion { + if table.Version != "" && session.statement.CheckVersion { verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { session.engine.logger.Errorf("[SQL][%p] %v", session, err) @@ -553,49 +421,48 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue.Set(int64ToIntValue(id, aiValue.Type())) return 1, nil - } else { - res, err := session.exec(sqlStr, args...) - if err != nil { - session.engine.logger.Errorf("[SQL][%p] %v", session, err) - return 0, err - } - - defer handleAfterInsertProcessorFunc(bean) - - session.cacheInsert(tableName) + } - if table.Version != "" && session.statement.checkVersion { - verValue, err := table.VersionColumn().ValueOf(bean) - if err != nil { - session.engine.logger.Errorf("[SQL][%p] %v", session, err) - } else if verValue.IsValid() && verValue.CanSet() { - session.incrVersionFieldValue(verValue) - } - } + res, err := session.exec(sqlStr, args...) 
+ if err != nil { + return 0, err + } - if table.AutoIncrement == "" { - return res.RowsAffected() - } + defer handleAfterInsertProcessorFunc(bean) - var id int64 - id, err = res.LastInsertId() - if err != nil || id <= 0 { - return res.RowsAffected() - } + session.cacheInsert(tableName) - aiValue, err := table.AutoIncrColumn().ValueOf(bean) + if table.Version != "" && session.statement.CheckVersion { + verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { - session.engine.logger.Errorf("[SQL][%p] %v", session, err) + session.engine.logger.Errorf("%v", err) + } else if verValue.IsValid() && verValue.CanSet() { + session.incrVersionFieldValue(verValue) } + } - if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { - return res.RowsAffected() - } + if table.AutoIncrement == "" { + return res.RowsAffected() + } - aiValue.Set(int64ToIntValue(id, aiValue.Type())) + var id int64 + id, err = res.LastInsertId() + if err != nil || id <= 0 { + return res.RowsAffected() + } + + aiValue, err := table.AutoIncrColumn().ValueOf(bean) + if err != nil { + session.engine.logger.Errorf("%v", err) + } + if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { return res.RowsAffected() } + + aiValue.Set(int64ToIntValue(id, aiValue.Type())) + + return res.RowsAffected() } // InsertOne insert only one struct into database as a record. @@ -613,11 +480,11 @@ func (session *Session) cacheInsert(table string) error { if !session.statement.UseCache { return nil } - cacher := session.engine.getCacher(table) + cacher := session.engine.cacherMgr.GetCacher(table) if cacher == nil { return nil } - session.engine.logger.Debug("[cache] clear sql:", table) + session.engine.logger.Debugf("[cache] clear SQL: %v", table) cacher.ClearIds(table) return nil } @@ -629,7 +496,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac args := make([]interface{}, 0, len(table.ColumnsSeq())) for _, col := range table.Columns() { - if col.MapType == core.ONLYFROMDB { + if col.MapType == schemas.ONLYFROMDB { continue } @@ -637,19 +504,19 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac continue } - if session.statement.omitColumnMap.contain(col.Name) { + if session.statement.OmitColumnMap.Contain(col.Name) { continue } - if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) { continue } - if session.statement.incrColumns.isColExist(col.Name) { + if session.statement.IncrColumns.IsColExist(col.Name) { continue - } else if session.statement.decrColumns.isColExist(col.Name) { + } else if session.statement.DecrColumns.IsColExist(col.Name) { continue - } else if session.statement.exprColumns.isColExist(col.Name) { + } else if session.statement.ExprColumns.IsColExist(col.Name) { continue } @@ -659,30 +526,13 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac } fieldValue := *fieldValuePtr - if col.IsAutoIncrement { - switch fieldValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: - if fieldValue.Int() == 0 { - continue - } - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: - if fieldValue.Uint() == 0 { - continue - } - case reflect.String: - if len(fieldValue.String()) == 0 { - continue - } - case reflect.Ptr: - if fieldValue.Pointer() == 0 { - continue - } - } + if col.IsAutoIncrement && 
utils.IsValueZero(fieldValue) { + continue } // !evalphobia! set fieldValue as nil when column is nullable and zero-value - if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { - if col.Nullable && isZeroValue(fieldValue) { + if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok { + if col.Nullable && utils.IsValueZero(fieldValue) { var nilValue *int fieldValue = reflect.ValueOf(nilValue) } @@ -698,10 +548,10 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac col := table.GetColumn(colName) setColumnTime(bean, col, t) }) - } else if col.IsVersion && session.statement.checkVersion { + } else if col.IsVersion && session.statement.CheckVersion { args = append(args, 1) } else { - arg, err := session.value2Interface(col, fieldValue) + arg, err := session.statement.Value2Interface(col, fieldValue) if err != nil { return colNames, args, err } @@ -724,9 +574,9 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err } var columns = make([]string, 0, len(m)) - exprs := session.statement.exprColumns + exprs := session.statement.ExprColumns for k := range m { - if !exprs.isColExist(k) { + if !exprs.IsColExist(k) { columns = append(columns, k) } } @@ -737,68 +587,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err args = append(args, m[colName]) } - w := builder.NewWriter() - if session.statement.cond.IsValid() { - if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { - return 0, err - } - - if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil { - return 0, err - } - - if _, err := w.WriteString(") SELECT "); err != nil { - return 0, err - } - - if err := session.statement.writeArgs(w, args); err != nil { - return 0, err - } - - if len(exprs.args) > 0 { - if _, err := w.WriteString(","); err != nil { - return 0, err - } - if err := exprs.writeArgs(w); err != nil { - return 0, err - } - } - - if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil { - return 0, err - } - - if err := session.statement.cond.WriteTo(w); err != nil { - return 0, err - } - - } else { - qm := strings.Repeat("?,", len(columns)) - qm = qm[:len(qm)-1] - - if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil { - return 0, err - } - w.Append(args...) - } - - sql := w.String() - args = w.Args() - - if err := session.cacheInsert(tableName); err != nil { - return 0, err - } - - res, err := session.exec(sql, args...) 
- if err != nil { - session.engine.logger.Errorf("[SQL][%p] %v", session, err) - return 0, err - } - affected, err := res.RowsAffected() - if err != nil { - return 0, err - } - return affected, nil + return session.insertMap(columns, args) } func (session *Session) insertMapString(m map[string]string) (int64, error) { @@ -812,12 +601,13 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { } var columns = make([]string, 0, len(m)) - exprs := session.statement.exprColumns + exprs := session.statement.ExprColumns for k := range m { - if !exprs.isColExist(k) { + if !exprs.IsColExist(k) { columns = append(columns, k) } } + sort.Strings(columns) var args = make([]interface{}, 0, len(m)) @@ -825,52 +615,19 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { args = append(args, m[colName]) } - w := builder.NewWriter() - if session.statement.cond.IsValid() { - if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { - return 0, err - } - - if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil { - return 0, err - } - - if _, err := w.WriteString(") SELECT "); err != nil { - return 0, err - } - - if err := session.statement.writeArgs(w, args); err != nil { - return 0, err - } - - if len(exprs.args) > 0 { - if _, err := w.WriteString(","); err != nil { - return 0, err - } - if err := exprs.writeArgs(w); err != nil { - return 0, err - } - } - - if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil { - return 0, err - } - - if err := session.statement.cond.WriteTo(w); err != nil { - return 0, err - } - } else { - qm := strings.Repeat("?,", len(columns)) - qm = qm[:len(qm)-1] + return session.insertMap(columns, args) +} - if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil { - return 0, err - } - w.Append(args...) 
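
Both map-insert fronts above now reduce to the same path: collect the non-expression keys (the string variant explicitly sorts them so the generated SQL is deterministic, which matters for SQL-keyed caches), align the args with the columns, and hand off to insertMap. A sketch of that normalization, where isExpr is a hypothetical stand-in for statement.ExprColumns.IsColExist:

package example

import "sort"

// normalizeMapInsert sketches the shared preparation step.
func normalizeMapInsert(m map[string]interface{}, isExpr func(string) bool) ([]string, []interface{}) {
	columns := make([]string, 0, len(m))
	for k := range m {
		if !isExpr(k) { // expression columns render as SQL, not bound args
			columns = append(columns, k)
		}
	}
	sort.Strings(columns) // map iteration order is random; sort for stable SQL
	args := make([]interface{}, 0, len(columns))
	for _, c := range columns {
		args = append(args, m[c])
	}
	return columns, args
}
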
+func (session *Session) insertMap(columns []string, args []interface{}) (int64, error) { + tableName := session.statement.TableName() + if len(tableName) <= 0 { + return 0, ErrTableNotFound } - sql := w.String() - args = w.Args() + sql, args, err := session.statement.GenInsertMapSQL(columns, args) + if err != nil { + return 0, err + } if err := session.cacheInsert(tableName); err != nil { return 0, err diff --git a/session_iterate.go b/session_iterate.go index 4a3cc08..a5c74f7 100644 --- a/session_iterate.go +++ b/session_iterate.go @@ -6,6 +6,8 @@ package xorm import ( "reflect" + + "github.com/xormplus/xorm/internal/utils" ) // IterFunc only use by Iterate @@ -25,11 +27,11 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { defer session.Close() } - if session.statement.lastError != nil { - return session.statement.lastError + if session.statement.LastError != nil { + return session.statement.LastError } - if session.statement.bufferSize > 0 { + if session.statement.BufferSize > 0 { return session.bufferIterate(bean, fun) } @@ -57,18 +59,18 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { // BufferSize sets the buffersize for iterate func (session *Session) BufferSize(size int) *Session { - session.statement.bufferSize = size + session.statement.BufferSize = size return session } func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error { - var bufferSize = session.statement.bufferSize + var bufferSize = session.statement.BufferSize var pLimitN = session.statement.LimitN if pLimitN != nil && bufferSize > *pLimitN { bufferSize = *pLimitN } var start = session.statement.Start - v := rValue(bean) + v := utils.ReflectValue(bean) sliceType := reflect.SliceOf(v.Type()) var idx = 0 session.autoResetStatement = false diff --git a/session_pk_test.go b/session_pk_test.go deleted file mode 100644 index 275086c..0000000 --- a/session_pk_test.go +++ /dev/null @@ -1,1199 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
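
The session_iterate.go hunk above keeps bufferIterate's windowed-find technique: rows are fetched in Limit(bufferSize)-sized pages, the callback runs per row, and the offset advances one page at a time until a short page comes back (the cap of bufferSize by LimitN is omitted here). A sketch of that loop, with find as a hypothetical stand-in for session.Limit(size, start) plus the slice find:

package example

// bufferIterateSketch pages through results until a short page appears.
func bufferIterateSketch(bufferSize int,
	find func(limit, start int) ([]interface{}, error),
	fun func(idx int, bean interface{}) error) error {
	idx, start := 0, 0
	for {
		beans, err := find(bufferSize, start)
		if err != nil {
			return err
		}
		for _, bean := range beans {
			if err := fun(idx, bean); err != nil {
				return err
			}
			idx++
		}
		if len(beans) < bufferSize {
			return nil // short page: no more rows
		}
		start += bufferSize
	}
}
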
- -package xorm - -import ( - "errors" - "sort" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/xormplus/core" -) - -type IntId struct { - Id int `xorm:"pk autoincr"` - Name string -} - -type Int16Id struct { - Id int16 `xorm:"pk autoincr"` - Name string -} - -type Int32Id struct { - Id int32 `xorm:"pk autoincr"` - Name string -} - -type UintId struct { - Id uint `xorm:"pk autoincr"` - Name string -} - -type Uint16Id struct { - Id uint16 `xorm:"pk autoincr"` - Name string -} - -type Uint32Id struct { - Id uint32 `xorm:"pk autoincr"` - Name string -} - -type Uint64Id struct { - Id uint64 `xorm:"pk autoincr"` - Name string -} - -type StringPK struct { - Id string `xorm:"pk notnull"` - Name string -} - -type ID int64 -type MyIntPK struct { - ID ID `xorm:"pk autoincr"` - Name string -} - -type StrID string -type MyStringPK struct { - ID StrID `xorm:"pk notnull"` - Name string -} - -func TestIntId(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&IntId{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&IntId{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&IntId{Name: "test"}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(IntId) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans := make([]IntId, 0) - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans2 := make(map[int]IntId) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - cnt, err = testEngine.ID(bean.Id).Delete(&IntId{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func TestInt16Id(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&Int16Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&Int16Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&Int16Id{Name: "test"}) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(Int16Id) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans := make([]Int16Id, 0) - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans2 := make(map[int16]Int16Id, 0) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - cnt, err = testEngine.ID(bean.Id).Delete(&Int16Id{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func TestInt32Id(t 
*testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&Int32Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&Int32Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&Int32Id{Name: "test"}) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(Int32Id) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans := make([]Int32Id, 0) - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans2 := make(map[int32]Int32Id, 0) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - cnt, err = testEngine.ID(bean.Id).Delete(&Int32Id{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func TestUintId(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&UintId{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&UintId{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&UintId{Name: "test"}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - var inserts = []UintId{ - {Name: "test1"}, - {Name: "test2"}, - } - cnt, err = testEngine.Insert(&inserts) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 2 { - err = errors.New("insert count should be two") - t.Error(err) - panic(err) - } - - bean := new(UintId) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans := make([]UintId, 0) - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 3 { - err = errors.New("get count should be three") - t.Error(err) - panic(err) - } - - beans2 := make(map[uint]UintId, 0) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 3 { - err = errors.New("get count should be three") - t.Error(err) - panic(err) - } - - cnt, err = testEngine.ID(bean.Id).Delete(&UintId{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func TestUint16Id(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&Uint16Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&Uint16Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&Uint16Id{Name: "test"}) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(Uint16Id) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans 
:= make([]Uint16Id, 0) - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans2 := make(map[uint16]Uint16Id, 0) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - cnt, err = testEngine.ID(bean.Id).Delete(&Uint16Id{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func TestUint32Id(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&Uint32Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&Uint32Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&Uint32Id{Name: "test"}) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(Uint32Id) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans := make([]Uint32Id, 0) - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans2 := make(map[uint32]Uint32Id, 0) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - cnt, err = testEngine.ID(bean.Id).Delete(&Uint32Id{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func TestUint64Id(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&Uint64Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&Uint64Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - idbean := &Uint64Id{Name: "test"} - cnt, err := testEngine.Insert(idbean) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(Uint64Id) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if bean.Id != idbean.Id { - panic(errors.New("should be equal")) - } - - beans := make([]Uint64Id, 0) - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if *bean != beans[0] { - panic(errors.New("should be equal")) - } - - beans2 := make(map[uint64]Uint64Id, 0) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if *bean != beans2[bean.Id] { - panic(errors.New("should be equal")) - } - - cnt, err = testEngine.ID(bean.Id).Delete(&Uint64Id{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func 
TestStringPK(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&StringPK{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&StringPK{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&StringPK{Id: "1-1-2", Name: "test"}) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(StringPK) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans := make([]StringPK, 0) - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans2 := make(map[string]StringPK) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - cnt, err = testEngine.ID(bean.Id).Delete(&StringPK{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -type CompositeKey struct { - Id1 int64 `xorm:"id1 pk"` - Id2 int64 `xorm:"id2 pk"` - UpdateStr string -} - -func TestCompositeKey(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&CompositeKey{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&CompositeKey{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&CompositeKey{11, 22, ""}) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("failed to insert CompositeKey{11, 22}")) - } - - cnt, err = testEngine.Insert(&CompositeKey{11, 22, ""}) - if err == nil || cnt == 1 { - t.Error(errors.New("inserted CompositeKey{11, 22}")) - } - - var compositeKeyVal CompositeKey - has, err := testEngine.ID(core.PK{11, 22}).Get(&compositeKeyVal) - if err != nil { - t.Error(err) - } else if !has { - t.Error(errors.New("can't get CompositeKey{11, 22}")) - } - - var compositeKeyVal2 CompositeKey - // test passing PK ptr, this test seem failed withCache - has, err = testEngine.ID(&core.PK{11, 22}).Get(&compositeKeyVal2) - if err != nil { - t.Error(err) - } else if !has { - t.Error(errors.New("can't get CompositeKey{11, 22}")) - } - - if compositeKeyVal != compositeKeyVal2 { - t.Error(errors.New("should be equal")) - } - - var cps = make([]CompositeKey, 0) - err = testEngine.Find(&cps) - if err != nil { - t.Error(err) - } - if len(cps) != 1 { - t.Error(errors.New("should has one record")) - } - if cps[0] != compositeKeyVal { - t.Error(errors.New("should be equal")) - } - - cnt, err = testEngine.Insert(&CompositeKey{22, 22, ""}) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("failed to insert CompositeKey{22, 22}")) - } - - cps = make([]CompositeKey, 0) - err = testEngine.Find(&cps) - assert.NoError(t, err) - assert.EqualValues(t, 2, len(cps), "should has two record") - assert.EqualValues(t, compositeKeyVal, cps[0], "should be equeal") - - compositeKeyVal = CompositeKey{UpdateStr: "test1"} - cnt, err = testEngine.ID(core.PK{11, 22}).Update(&compositeKeyVal) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("can't update CompositeKey{11, 22}")) - } - - cnt, err = 
testEngine.ID(core.PK{11, 22}).Delete(&CompositeKey{}) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("can't delete CompositeKey{11, 22}")) - } -} - -func TestCompositeKey2(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type User struct { - UserId string `xorm:"varchar(19) not null pk"` - NickName string `xorm:"varchar(19) not null"` - GameId uint32 `xorm:"integer pk"` - Score int32 `xorm:"integer"` - } - - err := testEngine.DropTables(&User{}) - - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&User{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&User{"11", "nick", 22, 5}) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("failed to insert User{11, 22}")) - } - - cnt, err = testEngine.Insert(&User{"11", "nick", 22, 6}) - if err == nil || cnt == 1 { - t.Error(errors.New("inserted User{11, 22}")) - } - - var user User - has, err := testEngine.ID(core.PK{"11", 22}).Get(&user) - if err != nil { - t.Error(err) - } else if !has { - t.Error(errors.New("can't get User{11, 22}")) - } - - // test passing PK ptr, this test seem failed withCache - has, err = testEngine.ID(&core.PK{"11", 22}).Get(&user) - if err != nil { - t.Error(err) - } else if !has { - t.Error(errors.New("can't get User{11, 22}")) - } - - user = User{NickName: "test1"} - cnt, err = testEngine.ID(core.PK{"11", 22}).Update(&user) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("can't update User{11, 22}")) - } - - cnt, err = testEngine.ID(core.PK{"11", 22}).Delete(&User{}) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("can't delete CompositeKey{11, 22}")) - } -} - -type MyString string -type UserPK2 struct { - UserId MyString `xorm:"varchar(19) not null pk"` - NickName string `xorm:"varchar(19) not null"` - GameId uint32 `xorm:"integer pk"` - Score int32 `xorm:"integer"` -} - -func TestCompositeKey3(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&UserPK2{}) - - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&UserPK2{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&UserPK2{"11", "nick", 22, 5}) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("failed to insert User{11, 22}")) - } - - cnt, err = testEngine.Insert(&UserPK2{"11", "nick", 22, 6}) - if err == nil || cnt == 1 { - t.Error(errors.New("inserted User{11, 22}")) - } - - var user UserPK2 - has, err := testEngine.ID(core.PK{"11", 22}).Get(&user) - if err != nil { - t.Error(err) - } else if !has { - t.Error(errors.New("can't get User{11, 22}")) - } - - // test passing PK ptr, this test seem failed withCache - has, err = testEngine.ID(&core.PK{"11", 22}).Get(&user) - if err != nil { - t.Error(err) - } else if !has { - t.Error(errors.New("can't get User{11, 22}")) - } - - user = UserPK2{NickName: "test1"} - cnt, err = testEngine.ID(core.PK{"11", 22}).Update(&user) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("can't update User{11, 22}")) - } - - cnt, err = testEngine.ID(core.PK{"11", 22}).Delete(&UserPK2{}) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("can't delete CompositeKey{11, 22}")) - } -} - -func TestMyIntId(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&MyIntPK{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = 
testEngine.CreateTables(&MyIntPK{}) - if err != nil { - t.Error(err) - panic(err) - } - - idbean := &MyIntPK{Name: "test"} - cnt, err := testEngine.Insert(idbean) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(MyIntPK) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if bean.ID != idbean.ID { - panic(errors.New("should be equal")) - } - - var beans []MyIntPK - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if *bean != beans[0] { - panic(errors.New("should be equal")) - } - - beans2 := make(map[ID]MyIntPK, 0) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if *bean != beans2[bean.ID] { - panic(errors.New("should be equal")) - } - - cnt, err = testEngine.ID(bean.ID).Delete(&MyIntPK{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func TestMyStringId(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&MyStringPK{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&MyStringPK{}) - if err != nil { - t.Error(err) - panic(err) - } - - idbean := &MyStringPK{ID: "1111", Name: "test"} - cnt, err := testEngine.Insert(idbean) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(MyStringPK) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if bean.ID != idbean.ID { - panic(errors.New("should be equal")) - } - - var beans []MyStringPK - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if *bean != beans[0] { - panic(errors.New("should be equal")) - } - - beans2 := make(map[StrID]MyStringPK, 0) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if *bean != beans2[bean.ID] { - panic(errors.New("should be equal")) - } - - cnt, err = testEngine.ID(bean.ID).Delete(&MyStringPK{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func TestSingleAutoIncrColumn(t *testing.T) { - type Account struct { - Id int64 `xorm:"pk autoincr"` - } - - assert.NoError(t, prepareEngine()) - assertSync(t, new(Account)) - - _, err := testEngine.Insert(&Account{}) - assert.NoError(t, err) -} - -func TestCompositePK(t *testing.T) { - type TaskSolution struct { - UID string `xorm:"notnull pk UUID 'uid'"` - TID string `xorm:"notnull pk UUID 'tid'"` - Created time.Time `xorm:"created"` - Updated time.Time `xorm:"updated"` - } - - assert.NoError(t, prepareEngine()) - - tables1, err := testEngine.DBMetas() - assert.NoError(t, err) - - 
assertSync(t, new(TaskSolution)) - assert.NoError(t, testEngine.Sync2(new(TaskSolution))) - - tables2, err := testEngine.DBMetas() - assert.NoError(t, err) - assert.EqualValues(t, 1+len(tables1), len(tables2)) - - var table *core.Table - for _, t := range tables2 { - if t.Name == testEngine.GetTableMapper().Obj2Table("TaskSolution") { - table = t - break - } - } - - assert.NotEqual(t, nil, table) - - pkCols := table.PKColumns() - assert.EqualValues(t, 2, len(pkCols)) - names := []string{pkCols[0].Name, pkCols[1].Name} - sort.Strings(names) - assert.EqualValues(t, []string{"tid", "uid"}, names) -} - -func TestNoPKIdQueryUpdate(t *testing.T) { - type NoPKTable struct { - Username string - } - - assert.NoError(t, prepareEngine()) - assertSync(t, new(NoPKTable)) - - cnt, err := testEngine.Insert(&NoPKTable{ - Username: "test", - }) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - var res NoPKTable - has, err := testEngine.ID("test").Get(&res) - assert.Error(t, err) - assert.False(t, has) - - cnt, err = testEngine.ID("test").Update(&NoPKTable{ - Username: "test1", - }) - assert.Error(t, err) - assert.EqualValues(t, 0, cnt) - - type UnvalidPKTable struct { - ID int `xorm:"id"` - Username string - } - - assertSync(t, new(UnvalidPKTable)) - - cnt, err = testEngine.Insert(&UnvalidPKTable{ - ID: 1, - Username: "test", - }) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - var res2 UnvalidPKTable - has, err = testEngine.ID(1).Get(&res2) - assert.Error(t, err) - assert.False(t, has) - - cnt, err = testEngine.ID(1).Update(&UnvalidPKTable{ - Username: "test1", - }) - assert.Error(t, err) - assert.EqualValues(t, 0, cnt) -} diff --git a/session_plus.go b/session_plus.go index 4e18e37..4571df1 100644 --- a/session_plus.go +++ b/session_plus.go @@ -12,10 +12,12 @@ import ( "reflect" "regexp" "strings" - "time" "github.com/Chronokeeper/anyxml" - "github.com/xormplus/core" + "github.com/xormplus/xorm/core" + "github.com/xormplus/xorm/dialects" + "github.com/xormplus/xorm/internal/utils" + "github.com/xormplus/xorm/schemas" ) type Record map[string]Value @@ -421,21 +423,21 @@ func (resultStructs *ResultStructs) XmlIndent(prefix string, indent string, reco } func (session *Session) SqlMapClient(sqlTagName string, args ...interface{}) *Session { - return session.Sql(session.engine.SqlMap.Sql[sqlTagName], args...) + return session.SQL(session.engine.SqlMap.Sql[sqlTagName], args...) } func (session *Session) SqlTemplateClient(sqlTagName string, args ...interface{}) *Session { session.isSqlFunc = true sql, err := session.engine.SqlTemplate.Execute(sqlTagName, args...) 
if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } if len(args) == 0 { - return session.Sql(sql) + return session.SQL(sql) } else { map1 := args[0].(*map[string]interface{}) - return session.Sql(sql, map1) + return session.SQL(sql, map1) } } @@ -446,13 +448,13 @@ func (session *Session) Search(rowsSlicePtr interface{}, condiBean ...interface{ return r } -func (session *Session) genSelectSql(dialect core.Dialect, rownumber string) string { +func (session *Session) genSelectSql(dialect dialects.Dialect, rownumber string) string { var sql = session.statement.RawSQL var orderBys = session.statement.OrderStr pLimitN := session.statement.LimitN - if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { + if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE { if session.statement.Start > 0 { sql = fmt.Sprintf("%v LIMIT %v OFFSET %v", sql, session.statement.LimitN, session.statement.Start) if pLimitN != nil { @@ -463,7 +465,7 @@ func (session *Session) genSelectSql(dialect core.Dialect, rownumber string) str } else if pLimitN != nil { sql = fmt.Sprintf("%v LIMIT %v", sql, session.statement.LimitN) } - } else if dialect.DBType() == core.ORACLE { + } else if dialect.URI().DBType == schemas.ORACLE { if session.statement.Start != 0 || pLimitN != nil { sql = fmt.Sprintf("SELECT aat.* FROM (SELECT at.*,ROWNUM %v FROM (%v) at WHERE ROWNUM <= %d) aat WHERE %v > %d", rownumber, sql, session.statement.Start+*pLimitN, rownumber, session.statement.Start) @@ -522,8 +524,8 @@ func (session *Session) Query() *ResultMap { defer session.Close() } - var dialect = session.statement.Engine.Dialect() - rownumber := "xorm" + NewShortUUID().String() + var dialect = session.engine.Dialect() + rownumber := "xorm" + utils.NewShortUUID().String() sql := session.genSelectSql(dialect, rownumber) params := session.statement.RawParams @@ -542,13 +544,13 @@ func (session *Session) Query() *ResultMap { result, err = session.queryAll(sql, params...) 
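
genSelectSql above wraps the raw query differently per dialect: a plain LIMIT/OFFSET where supported, and a double-ROWNUM subquery on Oracle (the synthetic rownumber column, "xorm" plus a short UUID, is stripped from the results afterwards in Query). A sketch of the Oracle wrapper built by the format string above:

package example

import "fmt"

// oraclePageSQL: the inner query bounds the upper ROWNUM, the outer
// query trims rows at or before the offset. rn is the synthetic column.
func oraclePageSQL(inner, rn string, start, limit int) string {
	return fmt.Sprintf(
		"SELECT aat.* FROM (SELECT at.*,ROWNUM %s FROM (%s) at WHERE ROWNUM <= %d) aat WHERE %s > %d",
		rn, inner, start+limit, rn, start)
}

For start=10 and limit=5 this selects rows 11 through 15 of the inner query, since ROWNUM can only be bounded from above in a single Oracle query level.
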
} pLimitN := session.statement.LimitN - if dialect.DBType() == core.MSSQL { + if dialect.URI().DBType == schemas.MSSQL { if session.statement.Start > 0 { for i, _ := range result { delete(result[i], rownumber) } } - } else if dialect.DBType() == core.ORACLE { + } else if dialect.URI().DBType == schemas.ORACLE { if session.statement.Start != 0 || pLimitN != nil { for i, _ := range result { delete(result[i], rownumber) @@ -566,8 +568,8 @@ func (session *Session) QueryWithDateFormat(dateFormat string) *ResultMap { defer session.Close() } - var dialect = session.statement.Engine.Dialect() - rownumber := "xorm" + NewShortUUID().String() + var dialect = session.engine.Dialect() + rownumber := "xorm" + utils.NewShortUUID().String() sql := session.genSelectSql(dialect, rownumber) params := session.statement.RawParams @@ -587,13 +589,13 @@ func (session *Session) QueryWithDateFormat(dateFormat string) *ResultMap { } pLimitN := session.statement.LimitN - if dialect.DBType() == core.MSSQL { + if dialect.URI().DBType == schemas.MSSQL { if session.statement.Start > 0 { for i, _ := range result { delete(result[i], rownumber) } } - } else if dialect.DBType() == core.ORACLE { + } else if dialect.URI().DBType == schemas.ORACLE { if session.statement.Start != 0 || pLimitN != nil { for i, _ := range result { delete(result[i], rownumber) @@ -633,24 +635,11 @@ func (session *Session) Execute() (sql.Result, error) { // ============================= func (session *Session) queryAll(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string]interface{}, err error) { session.queryPreprocess(&sqlStr, paramStr...) - - if session.engine.showSQL { - if session.engine.showExecTime { - b4ExecTime := time.Now() - defer func() { - execDuration := time.Since(b4ExecTime) - if len(paramStr) > 0 { - session.engine.logger.Infof("[SQL][%p] %s %#v - took: %v", session, sqlStr, paramStr, execDuration) - } else { - session.engine.logger.Infof("[SQL][%p] %s - took: %v", session, sqlStr, execDuration) - } - }() + if session.showSQL { + if len(paramStr) > 0 { + session.engine.logger.Infof("[SQL][%p] %v %#v", session, sqlStr, paramStr) } else { - if len(paramStr) > 0 { - session.engine.logger.Infof("[SQL][%p] %v %#v", session, sqlStr, paramStr) - } else { - session.engine.logger.Infof("[SQL][%p] %v", session, sqlStr) - } + session.engine.logger.Infof("[SQL][%p] %v", session, sqlStr) } } @@ -665,23 +654,11 @@ func (session *Session) queryAllByMap(sqlStr string, paramMap interface{}) (resu session.queryPreprocess(&sqlStr1, param...) - if session.engine.showSQL { - if session.engine.showExecTime { - b4ExecTime := time.Now() - defer func() { - execDuration := time.Since(b4ExecTime) - if len(param) > 0 { - session.engine.logger.Infof("[SQL][%p] %s %#v - took: %v", session, sqlStr1, param, execDuration) - } else { - session.engine.logger.Infof("[SQL][%p] %s - took: %v", session, sqlStr1, execDuration) - } - }() + if session.showSQL { + if len(param) > 0 { + session.engine.logger.Infof("[SQL][%p] %v %#v", session, sqlStr1, param) } else { - if len(param) > 0 { - session.engine.logger.Infof("[SQL][%p] %v %#v", session, sqlStr1, param) - } else { - session.engine.logger.Infof("[SQL][%p] %v", session, sqlStr1) - } + session.engine.logger.Infof("[SQL][%p] %v", session, sqlStr1) } } @@ -695,23 +672,11 @@ func (session *Session) queryAllByMapWithDateFormat(dateFormat string, sqlStr st sqlStr1, param, _ := core.MapToSlice(sqlStr, paramMap) session.queryPreprocess(&sqlStr1, param...) 
- if session.engine.showSQL { - if session.engine.showExecTime { - b4ExecTime := time.Now() - defer func() { - execDuration := time.Since(b4ExecTime) - if len(param) > 0 { - session.engine.logger.Infof("[SQL][%p] %s %#v - took: %v", session, sqlStr1, param, execDuration) - } else { - session.engine.logger.Infof("[SQL][%p] %s - took: %v", session, sqlStr1, execDuration) - } - }() + if session.showSQL { + if len(param) > 0 { + session.engine.logger.Infof("[SQL][%p] %v %#v", session, sqlStr1, param) } else { - if len(param) > 0 { - session.engine.logger.Infof("[SQL][%p] %v %#v", session, sqlStr1, param) - } else { - session.engine.logger.Infof("[SQL][%p] %v", session, sqlStr1) - } + session.engine.logger.Infof("[SQL][%p] %v", session, sqlStr1) } } @@ -724,23 +689,11 @@ func (session *Session) queryAllByMapWithDateFormat(dateFormat string, sqlStr st func (session *Session) queryAllWithDateFormat(dateFormat string, sqlStr string, paramStr ...interface{}) (resultsSlice []map[string]interface{}, err error) { session.queryPreprocess(&sqlStr, paramStr...) - if session.engine.showSQL { - if session.engine.showExecTime { - b4ExecTime := time.Now() - defer func() { - execDuration := time.Since(b4ExecTime) - if len(paramStr) > 0 { - session.engine.logger.Infof("[SQL][%p] %s %#v - took: %v", session, sqlStr, paramStr, execDuration) - } else { - session.engine.logger.Infof("[SQL][%p] %s - took: %v", session, sqlStr, execDuration) - } - }() + if session.showSQL { + if len(paramStr) > 0 { + session.engine.logger.Infof("[SQL][%p] %v %#v", session, sqlStr, paramStr) } else { - if len(paramStr) > 0 { - session.engine.logger.Infof("[SQL][%p] %v %#v", session, sqlStr, paramStr) - } else { - session.engine.logger.Infof("[SQL][%p] %v", session, sqlStr) - } + session.engine.logger.Infof("[SQL][%p] %v", session, sqlStr) } } @@ -844,11 +797,12 @@ func (session *Session) queryPreprocessByMap(sqlStr *string, paramMap interface{ }) for _, filter := range session.engine.dialect.Filters() { - query = filter.Do(query, session.engine.dialect, session.statement.RefTable) + query = filter.Do(query) } *sqlStr = query - session.engine.logSQL(session, *sqlStr, paramMap) + // session.engine.logSQL(session, *sqlStr, paramMap) + session.logSQL(*sqlStr, paramMap) } func (session *Session) Sqls(sqls interface{}, parmas ...interface{}) *SqlsExecutor { diff --git a/session_query.go b/session_query.go index ab5ee44..37a4d76 100644 --- a/session_query.go +++ b/session_query.go @@ -8,101 +8,22 @@ import ( "fmt" "reflect" "strconv" - "strings" + + // "strings" "time" - "github.com/xormplus/builder" - "github.com/xormplus/core" + // "github.com/xormplus/builder" + "github.com/xormplus/xorm/core" + "github.com/xormplus/xorm/schemas" + // "github.com/xormplus/xorm/internal/statements" ) -func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) { - if len(sqlOrArgs) > 0 { - return convertSQLOrArgs(sqlOrArgs...) 
- } - - if session.statement.RawSQL != "" { - var dialect = session.statement.Engine.Dialect() - rownumber := "xorm" + NewShortUUID().String() - sql := session.genSelectSql(dialect, rownumber) - - params := session.statement.RawParams - i := len(params) - - // var result []map[string]interface{} - // var err error - if i == 1 { - vv := reflect.ValueOf(params[0]) - if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { - return sql, params, nil - } else { - sqlStr1, param, _ := core.MapToSlice(sql, params[0]) - return sqlStr1, param, nil - } - } else { - return sql, params, nil - } - // return session.statement.RawSQL, session.statement.RawParams, nil - } - - if len(session.statement.TableName()) <= 0 { - return "", nil, ErrTableNotFound - } - - var columnStr = session.statement.ColumnStr - if len(session.statement.selectStr) > 0 { - columnStr = session.statement.selectStr - } else { - if session.statement.JoinStr == "" { - if columnStr == "" { - if session.statement.GroupByStr != "" { - columnStr = session.engine.quoteColumns(session.statement.GroupByStr) - } else { - columnStr = session.statement.genColumnStr() - } - } - } else { - if columnStr == "" { - if session.statement.GroupByStr != "" { - columnStr = session.engine.quoteColumns(session.statement.GroupByStr) - } else { - columnStr = "*" - } - } - } - if columnStr == "" { - columnStr = "*" - } - } - - if err := session.statement.processIDParam(); err != nil { - return "", nil, err - } - - condSQL, condArgs, err := builder.ToSQL(session.statement.cond) - if err != nil { - return "", nil, err - } - - args := append(session.statement.joinArgs, condArgs...) - sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true, true) - if err != nil { - return "", nil, err - } - // for mssql and use limit - qs := strings.Count(sqlStr, "?") - if len(args)*2 == qs { - args = append(args, args...) - } - - return sqlStr, args, nil -} - func (session *Session) QueryValue(sqlOrArgs ...interface{}) ([]map[string]Value, error) { if session.isAutoClose { defer session.Close() } - sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) if err != nil { return nil, err } @@ -117,7 +38,7 @@ func (session *Session) QueryResult(sqlOrArgs ...interface{}) *ResultValue { defer session.Close() } - sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) if err != nil { return &ResultValue{Error: err} } @@ -132,7 +53,7 @@ func (session *Session) QueryBytes(sqlOrArgs ...interface{}) ([]map[string][]byt defer session.Close() } - sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) 
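The genQuerySQL logic removed above (now statements.GenQuerySQL) had one special case worth keeping in mind: when the single raw parameter was a pointer to a map, core.MapToSlice rewrote named ?name markers into positional ? placeholders and collected the bound values in marker order. A simplified illustration of that rewrite, where mapToSlice is an illustrative stand-in rather than the real core helper:

package main

import (
	"fmt"
	"regexp"
)

// mapToSlice rewrites ?name placeholders to plain ? and gathers the
// matching map values in the order the markers appear in the query.
func mapToSlice(sqlStr string, params map[string]interface{}) (string, []interface{}) {
	re := regexp.MustCompile(`\?([a-zA-Z_][a-zA-Z0-9_]*)`)
	var args []interface{}
	out := re.ReplaceAllStringFunc(sqlStr, func(m string) string {
		args = append(args, params[m[1:]])
		return "?"
	})
	return out, args
}

func main() {
	sqlStr, args := mapToSlice("SELECT * FROM t WHERE id=?id AND name=?name",
		map[string]interface{}{"id": 1, "name": "bob"})
	fmt.Println(sqlStr, args) // SELECT * FROM t WHERE id=? AND name=? [1 bob]
}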
if err != nil { return nil, err } @@ -165,8 +86,8 @@ func value2String(rawValue *reflect.Value) (str string, err error) { } // time type case reflect.Struct: - if aa.ConvertibleTo(core.TimeType) { - str = vv.Convert(core.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) + if aa.ConvertibleTo(schemas.TimeType) { + str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) } else { err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) } @@ -271,6 +192,7 @@ func rows2SliceString(rows *core.Rows) (resultsSlice [][]string, err error) { } resultsSlice = append(resultsSlice, record) } + return resultsSlice, nil } @@ -279,7 +201,7 @@ func (session *Session) QueryRows(sqlOrArgs ...interface{}) (*core.Rows, error) defer session.Close() } - sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) if err != nil { return nil, err } @@ -298,7 +220,7 @@ func (session *Session) QueryString(sqlOrArgs ...interface{}) ([]map[string]stri defer session.Close() } - sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) if err != nil { return nil, err } @@ -318,7 +240,7 @@ func (session *Session) QuerySliceString(sqlOrArgs ...interface{}) ([][]string, defer session.Close() } - sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) if err != nil { return nil, err } @@ -371,7 +293,7 @@ func (session *Session) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]i defer session.Close() } - sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) if err != nil { return nil, err } @@ -391,15 +313,15 @@ func (session *Session) QueryExpr(sqlOrArgs ...interface{}) sqlExpr { defer session.Close() } - sqlStr, args, err := session.genQuerySQL() + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) return sqlExpr{sqlExpr: ""} } sqlStr, err = ConvertToBoundSQL(sqlStr, args) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) return sqlExpr{sqlExpr: ""} } diff --git a/session_raw.go b/session_raw.go index 8060a3b..0b6cf9e 100644 --- a/session_raw.go +++ b/session_raw.go @@ -7,15 +7,15 @@ package xorm import ( "database/sql" "reflect" - "time" "github.com/xormplus/builder" - "github.com/xormplus/core" + "github.com/xormplus/xorm/core" + // "github.com/xormplus/xorm/internal/statements" ) func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { for _, filter := range session.engine.dialect.Filters() { - *sqlStr = filter.Do(*sqlStr, session.engine.dialect, session.statement.RefTable) + *sqlStr = filter.Do(*sqlStr) } session.lastSQL = *sqlStr @@ -24,28 +24,19 @@ func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Rows, error) { defer session.resetStatement() + if session.statement.LastError != nil { + return nil, session.statement.LastError + } session.queryPreprocess(&sqlStr, args...) 
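The value2String hunk above only swaps core.TimeType for schemas.TimeType; the underlying trick is that any struct type convertible to time.Time, including user-defined wrappers, is converted and formatted as RFC3339Nano. The same conversion in isolation, under the assumption that a local timeType plays the role of schemas.TimeType:

package main

import (
	"fmt"
	"reflect"
	"time"
)

var timeType = reflect.TypeOf(time.Time{})

// timeValueToString mirrors the struct branch of value2String: a value
// whose type converts to time.Time is formatted, anything else errors.
func timeValueToString(v reflect.Value) (string, error) {
	if v.Type().ConvertibleTo(timeType) {
		t := v.Convert(timeType).Interface().(time.Time)
		return t.Format(time.RFC3339Nano), nil
	}
	return "", fmt.Errorf("unsupported struct type %v", v.Type().Name())
}

func main() {
	type MyTime time.Time // user wrapper, still convertible to time.Time
	s, err := timeValueToString(reflect.ValueOf(MyTime(time.Unix(0, 0).UTC())))
	fmt.Println(s, err) // 1970-01-01T00:00:00Z <nil>
}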
+ session.lastSQL = sqlStr + session.lastSQLArgs = args if session.showSQL { - session.lastSQL = sqlStr - session.lastSQLArgs = args - if session.engine.showExecTime { - b4ExecTime := time.Now() - defer func() { - execDuration := time.Since(b4ExecTime) - if len(args) > 0 { - session.engine.logger.Infof("[SQL][%p] %s %#v - took: %v", session, sqlStr, args, execDuration) - } else { - session.engine.logger.Infof("[SQL][%p] %s - took: %v", session, sqlStr, execDuration) - } - }() + if len(args) > 0 { + session.engine.logger.Infof("[SQL][%p] %v %#v", session, sqlStr, args) } else { - if len(args) > 0 { - session.engine.logger.Infof("[SQL][%p] %v %#v", session, sqlStr, args) - } else { - session.engine.logger.Infof("[SQL][%p] %v", session, sqlStr) - } + session.engine.logger.Infof("[SQL][%p] %v", session, sqlStr) } } @@ -270,23 +261,14 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er session.queryPreprocess(&sqlStr, args...) - if session.engine.showSQL { - if session.engine.showExecTime { - b4ExecTime := time.Now() - defer func() { - execDuration := time.Since(b4ExecTime) - if len(args) > 0 { - session.engine.logger.Infof("[SQL][%p] %s %#v - took: %v", session, sqlStr, args, execDuration) - } else { - session.engine.logger.Infof("[SQL][%p] %s - took: %v", session, sqlStr, execDuration) - } - }() + session.lastSQL = sqlStr + session.lastSQLArgs = args + + if session.showSQL { + if len(args) > 0 { + session.engine.logger.Infof("[SQL][%p] %v %#v", session, sqlStr, args) } else { - if len(args) > 0 { - session.engine.logger.Infof("[SQL][%p] %v %#v", session, sqlStr, args) - } else { - session.engine.logger.Infof("[SQL][%p] %v", session, sqlStr) - } + session.engine.logger.Infof("[SQL][%p] %v", session, sqlStr) } } diff --git a/session_schema.go b/session_schema.go index 699f3f4..1273b70 100644 --- a/session_schema.go +++ b/session_schema.go @@ -5,11 +5,15 @@ package xorm import ( + "bufio" "database/sql" "fmt" + "io" + "os" "strings" - "github.com/xormplus/core" + "github.com/xormplus/xorm/internal/utils" + "github.com/xormplus/xorm/schemas" ) // Ping test if database is ok @@ -32,13 +36,18 @@ func (session *Session) CreateTable(bean interface{}) error { } func (session *Session) createTable(bean interface{}) error { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } - sqlStr := session.statement.genCreateTableSQL() - _, err := session.exec(sqlStr) - return err + sqlStrs := session.statement.GenCreateTableSQL() + for _, s := range sqlStrs { + _, err := session.exec(s) + if err != nil { + return err + } + } + return nil } // CreateIndexes create indexes @@ -51,11 +60,11 @@ func (session *Session) CreateIndexes(bean interface{}) error { } func (session *Session) createIndexes(bean interface{}) error { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } - sqls := session.statement.genIndexSQL() + sqls := session.statement.GenIndexSQL() for _, sqlStr := range sqls { _, err := session.exec(sqlStr) if err != nil { @@ -74,11 +83,11 @@ func (session *Session) CreateUniques(bean interface{}) error { } func (session *Session) createUniques(bean interface{}) error { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } - sqls := session.statement.genUniqueSQL() + sqls := session.statement.GenUniqueSQL() for _, sqlStr := range sqls { 
_, err := session.exec(sqlStr) if err != nil { @@ -98,11 +107,11 @@ func (session *Session) DropIndexes(bean interface{}) error { } func (session *Session) dropIndexes(bean interface{}) error { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } - sqls := session.statement.genDelIndexSQL() + sqls := session.statement.GenDelIndexSQL() for _, sqlStr := range sqls { _, err := session.exec(sqlStr) if err != nil { @@ -123,18 +132,16 @@ func (session *Session) DropTable(beanOrTableName interface{}) error { func (session *Session) dropTable(beanOrTableName interface{}) error { tableName := session.engine.TableName(beanOrTableName) - var needDrop = true - if !session.engine.dialect.SupportDropIfExists() { - sqlStr, args := session.engine.dialect.TableCheckSql(tableName) - results, err := session.queryBytes(sqlStr, args...) + sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true)) + if !checkIfExist { + exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName) if err != nil { return err } - needDrop = len(results) > 0 + checkIfExist = exist } - if needDrop { - sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true)) + if checkIfExist { _, err := session.exec(sqlStr) return err } @@ -153,9 +160,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) } func (session *Session) isTableExist(tableName string) (bool, error) { - sqlStr, args := session.engine.dialect.TableCheckSql(tableName) - results, err := session.queryBytes(sqlStr, args...) - return len(results) > 0, err + return session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName) } // IsTableEmpty if table have any records @@ -182,17 +187,17 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) { // find if index is exist according cols func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) { - indexes, err := session.engine.dialect.GetIndexes(tableName) + indexes, err := session.engine.dialect.GetIndexes(session.getQueryer(), session.ctx, tableName) if err != nil { return false, err } for _, index := range indexes { - if sliceEq(index.Cols, cols) { + if utils.SliceEq(index.Cols, cols) { if unique { - return index.Type == core.UniqueType, nil + return index.Type == schemas.UniqueType, nil } - return index.Type == core.IndexType, nil + return index.Type == schemas.IndexType, nil } } return false, nil @@ -200,21 +205,21 @@ func (session *Session) isIndexExist2(tableName string, cols []string, unique bo func (session *Session) addColumn(colName string) error { col := session.statement.RefTable.GetColumn(colName) - sql, args := session.statement.genAddColumnStr(col) - _, err := session.exec(sql, args...) 
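The createTable change earlier in this hunk is why GenCreateTableSQL now returns a slice: some dialects need more than one DDL statement per table (the commit log's "Return sqls for create table"), and the session simply executes them in order, stopping at the first failure. A sketch of that execution pattern, with exec as an injected callback rather than the real session.exec:

package main

import "fmt"

// execAll runs DDL statements in order and stops at the first error,
// the same loop createTable now runs over GenCreateTableSQL's result.
func execAll(stmts []string, exec func(string) error) error {
	for _, s := range stmts {
		if err := exec(s); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	err := execAll(
		[]string{"CREATE TABLE t (id INTEGER)", "CREATE INDEX i ON t (id)"},
		func(s string) error { fmt.Println("exec:", s); return nil },
	)
	fmt.Println(err) // <nil>
}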
+ sql := session.engine.dialect.AddColumnSQL(session.statement.TableName(), col) + _, err := session.exec(sql) return err } func (session *Session) addIndex(tableName, idxName string) error { index := session.statement.RefTable.Indexes[idxName] - sqlStr := session.engine.dialect.CreateIndexSql(tableName, index) + sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index) _, err := session.exec(sqlStr) return err } func (session *Session) addUnique(tableName, uqeName string) error { index := session.statement.RefTable.Indexes[uqeName] - sqlStr := session.engine.dialect.CreateIndexSql(tableName, index) + sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index) _, err := session.exec(sqlStr) return err } @@ -228,7 +233,7 @@ func (session *Session) Sync2(beans ...interface{}) error { defer session.Close() } - tables, err := engine.dialect.GetTables() + tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx) if err != nil { return err } @@ -240,8 +245,8 @@ func (session *Session) Sync2(beans ...interface{}) error { }() for _, bean := range beans { - v := rValue(bean) - table, err := engine.mapType(v) + v := utils.ReflectValue(bean) + table, err := engine.tagParser.ParseWithCache(v) if err != nil { return err } @@ -253,7 +258,7 @@ func (session *Session) Sync2(beans ...interface{}) error { } tbNameWithSchema := engine.tbNameWithSchema(tbName) - var oriTable *core.Table + var oriTable *schemas.Table for _, tb := range tables { if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) { oriTable = tb @@ -287,7 +292,7 @@ func (session *Session) Sync2(beans ...interface{}) error { // check columns for _, col := range table.Columns() { - var oriCol *core.Column + var oriCol *schemas.Column for _, col2 := range oriTable.Columns() { if strings.EqualFold(col.Name, col2.Name) { oriCol = col2 @@ -298,7 +303,7 @@ func (session *Session) Sync2(beans ...interface{}) error { // column is not exist on table if oriCol == nil { session.statement.RefTable = table - session.statement.tableName = tbNameWithSchema + session.statement.SetTableName(tbNameWithSchema) if err = session.addColumn(col.Name); err != nil { return err } @@ -306,27 +311,27 @@ func (session *Session) Sync2(beans ...interface{}) error { } err = nil - expectedType := engine.dialect.SqlType(col) - curType := engine.dialect.SqlType(oriCol) + expectedType := engine.dialect.SQLType(col) + curType := engine.dialect.SQLType(oriCol) if expectedType != curType { - if expectedType == core.Text && - strings.HasPrefix(curType, core.Varchar) { + if expectedType == schemas.Text && + strings.HasPrefix(curType, schemas.Varchar) { // currently only support mysql & postgres - if engine.dialect.DBType() == core.MYSQL || - engine.dialect.DBType() == core.POSTGRES { + if engine.dialect.URI().DBType == schemas.MYSQL || + engine.dialect.URI().DBType == schemas.POSTGRES { engine.logger.Infof("Table %s column %s change type from %s to %s\n", tbNameWithSchema, col.Name, curType, expectedType) - _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) + _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) } else { engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", tbNameWithSchema, col.Name, curType, expectedType) } - } else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) { - if engine.dialect.DBType() == core.MYSQL { + } else if strings.HasPrefix(curType, schemas.Varchar) && 
strings.HasPrefix(expectedType, schemas.Varchar) { + if engine.dialect.URI().DBType == schemas.MYSQL { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", tbNameWithSchema, col.Name, oriCol.Length, col.Length) - _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) + _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) } } } else { @@ -335,12 +340,12 @@ func (session *Session) Sync2(beans ...interface{}) error { tbNameWithSchema, col.Name, curType, expectedType) } } - } else if expectedType == core.Varchar { - if engine.dialect.DBType() == core.MYSQL { + } else if expectedType == schemas.Varchar { + if engine.dialect.URI().DBType == schemas.MYSQL { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", tbNameWithSchema, col.Name, oriCol.Length, col.Length) - _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) + _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) } } } @@ -348,7 +353,7 @@ func (session *Session) Sync2(beans ...interface{}) error { if col.Default != oriCol.Default { switch { case col.IsAutoIncrement: // For autoincrement column, don't check default - case (col.SQLType.Name == core.Bool || col.SQLType.Name == core.Boolean) && + case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) && ((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") || (strings.EqualFold(col.Default, "false") && oriCol.Default == "0")): default: @@ -367,10 +372,10 @@ func (session *Session) Sync2(beans ...interface{}) error { } var foundIndexNames = make(map[string]bool) - var addedNames = make(map[string]*core.Index) + var addedNames = make(map[string]*schemas.Index) for name, index := range table.Indexes { - var oriIndex *core.Index + var oriIndex *schemas.Index for name2, index2 := range oriTable.Indexes { if index.Equal(index2) { oriIndex = index2 @@ -381,7 +386,7 @@ func (session *Session) Sync2(beans ...interface{}) error { if oriIndex != nil { if oriIndex.Type != index.Type { - sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex) + sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex) _, err = session.exec(sql) if err != nil { return err @@ -397,7 +402,7 @@ func (session *Session) Sync2(beans ...interface{}) error { for name2, index2 := range oriTable.Indexes { if _, ok := foundIndexNames[name2]; !ok { - sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2) + sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2) _, err = session.exec(sql) if err != nil { return err @@ -406,13 +411,13 @@ func (session *Session) Sync2(beans ...interface{}) error { } for name, index := range addedNames { - if index.Type == core.UniqueType { + if index.Type == schemas.UniqueType { session.statement.RefTable = table - session.statement.tableName = tbNameWithSchema + session.statement.SetTableName(tbNameWithSchema) err = session.addUnique(tbNameWithSchema, name) - } else if index.Type == core.IndexType { + } else if index.Type == schemas.IndexType { session.statement.RefTable = table - session.statement.tableName = tbNameWithSchema + session.statement.SetTableName(tbNameWithSchema) err = session.addIndex(tbNameWithSchema, name) } if err != nil { @@ -430,3 +435,56 @@ func (session *Session) Sync2(beans ...interface{}) error { return nil } + +// ImportFile SQL DDL file +func (session *Session) ImportFile(ddlPath string) 
([]sql.Result, error) { + file, err := os.Open(ddlPath) + if err != nil { + return nil, err + } + defer file.Close() + return session.Import(file) +} + +// Import reads SQL DDL statements from r and executes them +func (session *Session) Import(r io.Reader) ([]sql.Result, error) { + var results []sql.Result + var lastError error + scanner := bufio.NewScanner(r) + + var inSingleQuote bool + semiColSplitter := func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + for i, b := range data { + if b == '\'' { + inSingleQuote = !inSingleQuote + } + if !inSingleQuote && b == ';' { + return i + 1, data[0:i], nil + } + } + // If we're at EOF, we have a final, non-terminated statement. Return it. + if atEOF { + return len(data), data, nil + } + // Request more data. + return 0, nil, nil + } + + scanner.Split(semiColSplitter) + + for scanner.Scan() { + query := strings.Trim(scanner.Text(), " \t\n\r") + if len(query) > 0 { + result, err := session.Exec(query) + results = append(results, result) + if err != nil { + return nil, err + } + } + } + + return results, lastError +} diff --git a/session_stats.go b/session_stats.go index c2cac83..17d0a67 100644 --- a/session_stats.go +++ b/session_stats.go @@ -17,17 +17,9 @@ func (session *Session) Count(bean ...interface{}) (int64, error) { defer session.Close() } - var sqlStr string - var args []interface{} - var err error - if session.statement.RawSQL == "" { - sqlStr, args, err = session.statement.genCountSQL(bean...) - if err != nil { - return 0, err - } - } else { - sqlStr = session.statement.RawSQL - args = session.statement.RawParams + sqlStr, args, err := session.statement.GenCountSQL(bean...) + if err != nil { + return 0, err } var total int64 @@ -50,21 +42,12 @@ func (session *Session) sum(res interface{}, bean interface{}, columnNames ...st return errors.New("need a pointer to a variable") } - var isSlice = v.Elem().Kind() == reflect.Slice - var sqlStr string - var args []interface{} - var err error - if len(session.statement.RawSQL) == 0 { - sqlStr, args, err = session.statement.genSumSQL(bean, columnNames...) - if err != nil { - return err - } - } else { - sqlStr = session.statement.RawSQL - args = session.statement.RawParams + sqlStr, args, err := session.statement.GenSumSQL(bean, columnNames...) + if err != nil { + return err } - if isSlice { + if v.Elem().Kind() == reflect.Slice { err = session.queryRow(sqlStr, args...).ScanSlice(res) } else { err = session.queryRow(sqlStr, args...).Scan(res) diff --git a/session_sum_test.go b/session_sum_test.go deleted file mode 100644 index 12f61cc..0000000 --- a/session_sum_test.go +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file.
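The splitter inside Import above is an ordinary bufio.SplitFunc: it cuts the stream on every semicolon that sits outside a single-quoted SQL string, so literals such as 'a;b' do not end a statement. The same idea as a self-contained program (semiSplit is a sketch of the technique, not the xorm implementation):

package main

import (
	"bufio"
	"fmt"
	"strings"
)

// semiSplit returns a SplitFunc that yields one SQL statement per
// token, splitting on ';' only while outside single quotes.
func semiSplit() bufio.SplitFunc {
	inQuote := false
	return func(data []byte, atEOF bool) (int, []byte, error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		for i, b := range data {
			if b == '\'' {
				inQuote = !inQuote
			}
			if !inQuote && b == ';' {
				return i + 1, data[:i], nil
			}
		}
		if atEOF { // final, non-terminated statement
			return len(data), data, nil
		}
		return 0, nil, nil // request more data
	}
}

func main() {
	sc := bufio.NewScanner(strings.NewReader("INSERT INTO t VALUES ('a;b');SELECT 1"))
	sc.Split(semiSplit())
	for sc.Scan() {
		fmt.Printf("%q\n", strings.TrimSpace(sc.Text()))
	}
	// "INSERT INTO t VALUES ('a;b')"
	// "SELECT 1"
}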
- -package xorm - -import ( - "fmt" - "strconv" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/xormplus/builder" -) - -func isFloatEq(i, j float64, precision int) bool { - return fmt.Sprintf("%."+strconv.Itoa(precision)+"f", i) == fmt.Sprintf("%."+strconv.Itoa(precision)+"f", j) -} - -func TestSum(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type SumStruct struct { - Int int - Float float32 - } - - var ( - cases = []SumStruct{ - {1, 6.2}, - {2, 5.3}, - {92, -0.2}, - } - ) - - var i int - var f float32 - for _, v := range cases { - i += v.Int - f += v.Float - } - - assert.NoError(t, testEngine.Sync2(new(SumStruct))) - - cnt, err := testEngine.Insert(cases) - assert.NoError(t, err) - assert.EqualValues(t, 3, cnt) - - colInt := testEngine.ColumnMapper.Obj2Table("Int") - colFloat := testEngine.ColumnMapper.Obj2Table("Float") - - sumInt, err := testEngine.Sum(new(SumStruct), colInt) - assert.NoError(t, err) - assert.EqualValues(t, int(sumInt), i) - - sumFloat, err := testEngine.Sum(new(SumStruct), colFloat) - assert.NoError(t, err) - assert.Condition(t, func() bool { - return isFloatEq(sumFloat, float64(f), 2) - }) - - sums, err := testEngine.Sums(new(SumStruct), colInt, colFloat) - assert.NoError(t, err) - assert.EqualValues(t, 2, len(sums)) - assert.EqualValues(t, i, int(sums[0])) - assert.Condition(t, func() bool { - return isFloatEq(sums[1], float64(f), 2) - }) - - sumsInt, err := testEngine.SumsInt(new(SumStruct), colInt) - assert.NoError(t, err) - assert.EqualValues(t, 1, len(sumsInt)) - assert.EqualValues(t, i, int(sumsInt[0])) -} - -func TestSumCustomColumn(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type SumStruct struct { - Int int - Float float32 - } - - var ( - cases = []SumStruct{ - {1, 6.2}, - {2, 5.3}, - {92, -0.2}, - } - ) - - assert.NoError(t, testEngine.Sync2(new(SumStruct))) - - cnt, err := testEngine.Insert(cases) - assert.NoError(t, err) - assert.EqualValues(t, 3, cnt) - - sumInt, err := testEngine.Sum(new(SumStruct), - "CASE WHEN `int` <= 2 THEN `int` ELSE 0 END") - assert.NoError(t, err) - assert.EqualValues(t, 3, int(sumInt)) -} - -func TestCount(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type UserinfoCount struct { - Departname string - } - assert.NoError(t, testEngine.Sync2(new(UserinfoCount))) - - colName := testEngine.ColumnMapper.Obj2Table("Departname") - var cond builder.Cond = builder.Eq{ - "`" + colName + "`": "dev", - } - - total, err := testEngine.Where(cond).Count(new(UserinfoCount)) - assert.NoError(t, err) - assert.EqualValues(t, 0, total) - - cnt, err := testEngine.Insert(&UserinfoCount{ - Departname: "dev", - }) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - total, err = testEngine.Where(cond).Count(new(UserinfoCount)) - assert.NoError(t, err) - assert.EqualValues(t, 1, total) -} - -func TestSQLCount(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type UserinfoCount2 struct { - Id int64 - Departname string - } - - type UserinfoBooks struct { - Id int64 - Pid int64 - IsOpen bool - } - - assertSync(t, new(UserinfoCount2), new(UserinfoBooks)) - - total, err := testEngine.SQL("SELECT count(id) FROM userinfo_count2"). 
- Count() - assert.NoError(t, err) - assert.EqualValues(t, 0, total) -} diff --git a/session_tx.go b/session_tx.go index ee3d473..94ed6f4 100644 --- a/session_tx.go +++ b/session_tx.go @@ -4,6 +4,12 @@ package xorm +import ( + "time" + + "github.com/xormplus/xorm/log" +) + // Begin a transaction func (session *Session) Begin() error { if session.isAutoCommit { @@ -14,6 +20,7 @@ func (session *Session) Begin() error { session.isAutoCommit = false session.isCommitedOrRollbacked = false session.tx = tx + session.saveLastSQL("BEGIN TRANSACTION") } return nil @@ -22,10 +29,28 @@ func (session *Session) Begin() error { // Rollback When using transaction, you can rollback if any error func (session *Session) Rollback() error { if !session.isAutoCommit && !session.isCommitedOrRollbacked { - session.saveLastSQL(session.engine.dialect.RollBackStr()) + session.saveLastSQL("ROLL BACK") session.isCommitedOrRollbacked = true session.isAutoCommit = true - return session.tx.Rollback() + + start := time.Now() + needSQL := session.DB().NeedLogSQL(session.ctx) + if needSQL { + session.engine.logger.BeforeSQL(log.LogContext{ + Ctx: session.ctx, + SQL: "ROLL BACK", + }) + } + err := session.tx.Rollback() + if needSQL { + session.engine.logger.AfterSQL(log.LogContext{ + Ctx: session.ctx, + SQL: "ROLL BACK", + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } + return err } return nil } @@ -36,48 +61,67 @@ func (session *Session) Commit() error { session.saveLastSQL("COMMIT") session.isCommitedOrRollbacked = true session.isAutoCommit = true - var err error - if err = session.tx.Commit(); err == nil { - // handle processors after tx committed - closureCallFunc := func(closuresPtr *[]func(interface{}), bean interface{}) { - if closuresPtr != nil { - for _, closure := range *closuresPtr { - closure(bean) - } + + start := time.Now() + needSQL := session.DB().NeedLogSQL(session.ctx) + if needSQL { + session.engine.logger.BeforeSQL(log.LogContext{ + Ctx: session.ctx, + SQL: "COMMIT", + }) + } + err := session.tx.Commit() + if needSQL { + session.engine.logger.AfterSQL(log.LogContext{ + Ctx: session.ctx, + SQL: "COMMIT", + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } + + if err != nil { + return err + } + + // handle processors after tx committed + closureCallFunc := func(closuresPtr *[]func(interface{}), bean interface{}) { + if closuresPtr != nil { + for _, closure := range *closuresPtr { + closure(bean) } } + } - for bean, closuresPtr := range session.afterInsertBeans { - closureCallFunc(closuresPtr, bean) + for bean, closuresPtr := range session.afterInsertBeans { + closureCallFunc(closuresPtr, bean) - if processor, ok := interface{}(bean).(AfterInsertProcessor); ok { - processor.AfterInsert() - } + if processor, ok := interface{}(bean).(AfterInsertProcessor); ok { + processor.AfterInsert() } - for bean, closuresPtr := range session.afterUpdateBeans { - closureCallFunc(closuresPtr, bean) + } + for bean, closuresPtr := range session.afterUpdateBeans { + closureCallFunc(closuresPtr, bean) - if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok { - processor.AfterUpdate() - } + if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok { + processor.AfterUpdate() } - for bean, closuresPtr := range session.afterDeleteBeans { - closureCallFunc(closuresPtr, bean) + } + for bean, closuresPtr := range session.afterDeleteBeans { + closureCallFunc(closuresPtr, bean) - if processor, ok := interface{}(bean).(AfterDeleteProcessor); ok { - processor.AfterDelete() - } + if processor, ok := 
interface{}(bean).(AfterDeleteProcessor); ok { + processor.AfterDelete() } - cleanUpFunc := func(slices *map[interface{}]*[]func(interface{})) { - if len(*slices) > 0 { - *slices = make(map[interface{}]*[]func(interface{}), 0) - } + } + cleanUpFunc := func(slices *map[interface{}]*[]func(interface{})) { + if len(*slices) > 0 { + *slices = make(map[interface{}]*[]func(interface{}), 0) } - cleanUpFunc(&session.afterInsertBeans) - cleanUpFunc(&session.afterUpdateBeans) - cleanUpFunc(&session.afterDeleteBeans) } - return err + cleanUpFunc(&session.afterInsertBeans) + cleanUpFunc(&session.afterUpdateBeans) + cleanUpFunc(&session.afterDeleteBeans) } return nil } diff --git a/session_tx_plus.go b/session_tx_plus.go index 30e6705..8d631f1 100644 --- a/session_tx_plus.go +++ b/session_tx_plus.go @@ -3,7 +3,8 @@ package xorm import ( "sync" - "github.com/xormplus/core" + "github.com/xormplus/xorm/internal/utils" + "github.com/xormplus/xorm/schemas" ) const ( @@ -150,11 +151,11 @@ func (transaction *Transaction) BeginTrans() error { } } else { transaction.isNested = true - dbtype := transaction.txSession.engine.Dialect().DBType() - if dbtype == core.MSSQL { - transaction.savePointID = "xorm" + NewShortUUID().String() + dbtype := transaction.txSession.engine.Dialect().URI().DBType + if dbtype == schemas.MSSQL { + transaction.savePointID = "xorm" + utils.NewShortUUID().String() } else { - transaction.savePointID = "xorm" + NewV1().WithoutDashString() + transaction.savePointID = "xorm" + utils.NewV1().WithoutDashString() } if err := transaction.SavePoint(transaction.savePointID); err != nil { @@ -377,8 +378,8 @@ func (transaction *Transaction) SavePoint(savePointID string) error { } var lastSQL string - dbtype := transaction.txSession.engine.Dialect().DBType() - if dbtype == core.MSSQL { + dbtype := transaction.txSession.engine.Dialect().URI().DBType + if dbtype == schemas.MSSQL { lastSQL = "save tran " + savePointID } else { lastSQL = "SAVEPOINT " + savePointID + ";" @@ -398,8 +399,8 @@ func (transaction *Transaction) RollbackToSavePoint(savePointID string) error { } var lastSQL string - dbtype := transaction.txSession.engine.Dialect().DBType() - if dbtype == core.MSSQL { + dbtype := transaction.txSession.engine.Dialect().URI().DBType + if dbtype == schemas.MSSQL { lastSQL = "rollback tran " + savePointID } else { lastSQL = "ROLLBACK TO SAVEPOINT " + transaction.savePointID + ";" diff --git a/session_update.go b/session_update.go index 6cd7f16..38f48bd 100644 --- a/session_update.go +++ b/session_update.go @@ -12,23 +12,25 @@ import ( "strings" "github.com/xormplus/builder" - "github.com/xormplus/core" + "github.com/xormplus/xorm/caches" + "github.com/xormplus/xorm/internal/utils" + "github.com/xormplus/xorm/schemas" ) -func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, args ...interface{}) error { +func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error { if table == nil || session.tx != nil { return ErrCacheFailed } - oldhead, newsql := session.statement.convertUpdateSQL(sqlStr) + oldhead, newsql := session.statement.ConvertUpdateSQL(sqlStr) if newsql == "" { return ErrCacheFailed } for _, filter := range session.engine.dialect.Filters() { - newsql = filter.Do(newsql, session.engine.dialect, table) + newsql = filter.Do(newsql) } - session.engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql) + session.engine.logger.Debugf("[cache] new sql: %v, %v", oldhead, newsql) var nStart int if len(args) > 0 { @@ -40,9 
+42,9 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, } } - cacher := session.engine.getCacher(tableName) - session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:]) - ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:]) + cacher := session.engine.GetCacher(tableName) + session.engine.logger.Debugf("[cache] get cache sql: %v, %v", newsql, args[nStart:]) + ids, err := caches.GetCacheSql(cacher, tableName, newsql, args[nStart:]) if err != nil { rows, err := session.NoCache().queryRows(newsql, args[nStart:]...) if err != nil { @@ -50,14 +52,14 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, } defer rows.Close() - ids = make([]core.PK, 0) + ids = make([]schemas.PK, 0) for rows.Next() { var res = make([]string, len(table.PrimaryKeys)) err = rows.ScanSlice(&res) if err != nil { return err } - var pk core.PK = make([]interface{}, len(table.PrimaryKeys)) + var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys)) for i, col := range table.PKColumns() { if col.SQLType.IsNumeric() { n, err := strconv.ParseInt(res[i], 10, 64) @@ -74,7 +76,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, ids = append(ids, pk) } - session.engine.logger.Debug("[cacheUpdate] find updated id", ids) + session.engine.logger.Debugf("[cache] find updated id: %v", ids) } /*else { session.engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args) cacher.DelIds(tableName, genSqlKey(newsql, args)) @@ -86,12 +88,12 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, return err } if bean := cacher.GetBean(tableName, sid); bean != nil { - sqls := splitNNoCase(sqlStr, "where", 2) + sqls := utils.SplitNNoCase(sqlStr, "where", 2) if len(sqls) == 0 || len(sqls) > 2 { return ErrCacheFailed } - sqls = splitNNoCase(sqls[0], "set", 2) + sqls = utils.SplitNNoCase(sqls[0], "set", 2) if len(sqls) != 2 { return ErrCacheFailed } @@ -101,38 +103,32 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, sps := strings.SplitN(kv, "=", 2) sps2 := strings.Split(sps[0], ".") colName := sps2[len(sps2)-1] - // treat quote prefix, suffix and '`' as quotes - quotes := append(strings.Split(session.engine.Quote(""), ""), "`") - if strings.ContainsAny(colName, strings.Join(quotes, "")) { - colName = strings.TrimSpace(eraseAny(colName, quotes...)) - } else { - session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName) - return ErrCacheFailed - } + colName = session.engine.dialect.Quoter().Trim(colName) + colName = schemas.CommonQuoter.Trim(colName) if col := table.GetColumn(colName); col != nil { fieldValue, err := col.ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } else { - session.engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) - if col.IsVersion && session.statement.checkVersion { + session.engine.logger.Debugf("[cache] set bean field: %v, %v, %v", bean, colName, fieldValue.Interface()) + if col.IsVersion && session.statement.CheckVersion { session.incrVersionFieldValue(fieldValue) } else { fieldValue.Set(reflect.ValueOf(args[idx])) } } } else { - session.engine.logger.Errorf("[cacheUpdate] ERROR: column %v is not table %v's", + session.engine.logger.Errorf("[cache] ERROR: column %v is not table %v's", colName, table.Name) } } - session.engine.logger.Debug("[cacheUpdate] update cache", tableName, 
id, bean) + session.engine.logger.Debugf("[cache] update cache: %v, %v, %v", tableName, id, bean) cacher.PutBean(tableName, sid, bean) } } - session.engine.logger.Debug("[cacheUpdate] clear cached table sql:", tableName) + session.engine.logger.Debugf("[cache] clear cached table sql: %v", tableName) cacher.ClearIds(tableName) return nil } @@ -148,11 +144,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 defer session.Close() } - if session.statement.lastError != nil { - return 0, session.statement.lastError + if session.statement.LastError != nil { + return 0, session.statement.LastError } - v := rValue(bean) + v := utils.ReflectValue(bean) t := v.Type() var colNames []string @@ -172,7 +168,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var isMap = t.Kind() == reflect.Map var isStruct = t.Kind() == reflect.Struct if isStruct { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return 0, err } @@ -180,14 +176,14 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return 0, ErrTableNotFound } - if session.statement.ColumnStr == "" { - colNames, args = session.statement.buildUpdates(bean, false, false, + if session.statement.ColumnStr() == "" { + colNames, args, err = session.statement.BuildUpdates(v, false, false, false, false, true) } else { colNames, args, err = session.genUpdateColumns(bean) - if err != nil { - return 0, err - } + } + if err != nil { + return 0, err } } else if isMap { colNames = make([]string, 0) @@ -205,8 +201,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 table := session.statement.RefTable if session.statement.UseAutoTime && table != nil && table.Updated != "" { - if !session.statement.columnMap.contain(table.Updated) && - !session.statement.omitColumnMap.contain(table.Updated) { + if !session.statement.ColumnMap.Contain(table.Updated) && + !session.statement.OmitColumnMap.Contain(table.Updated) { colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?") col := table.UpdatedColumn() val, t := session.engine.nowTime(col) @@ -223,39 +219,45 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } // for update action to like "column = column + ?" - incColumns := session.statement.incrColumns - for i, colName := range incColumns.colNames { + incColumns := session.statement.IncrColumns + for i, colName := range incColumns.ColNames { colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" + ?") - args = append(args, incColumns.args[i]) + args = append(args, incColumns.Args[i]) } // for update action to like "column = column - ?" 
- decColumns := session.statement.decrColumns - for i, colName := range decColumns.colNames { + decColumns := session.statement.DecrColumns + for i, colName := range decColumns.ColNames { colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" - ?") - args = append(args, decColumns.args[i]) + args = append(args, decColumns.Args[i]) } // for update action to like "column = expression" - exprColumns := session.statement.exprColumns - for i, colName := range exprColumns.colNames { - switch tp := exprColumns.args[i].(type) { + exprColumns := session.statement.ExprColumns + for i, colName := range exprColumns.ColNames { + switch tp := exprColumns.Args[i].(type) { case string: - colNames = append(colNames, session.engine.Quote(colName)+" = "+tp) + if len(tp) == 0 { + tp = "''" + } + colNames = append(colNames, session.engine.Quote(colName)+"="+tp) case *builder.Builder: - subQuery, subArgs, err := builder.ToSQL(tp) + subQuery, subArgs, err := session.statement.GenCondSQL(tp) if err != nil { return 0, err } - colNames = append(colNames, session.engine.Quote(colName)+" = ("+subQuery+")") + colNames = append(colNames, session.engine.Quote(colName)+"=("+subQuery+")") args = append(args, subArgs...) + default: + colNames = append(colNames, session.engine.Quote(colName)+"=?") + args = append(args, exprColumns.Args[i]) } } - if err = session.statement.processIDParam(); err != nil { + if err = session.statement.ProcessIDParam(); err != nil { return 0, err } var autoCond builder.Cond - if !session.statement.noAutoCondition { + if !session.statement.NoAutoCondition { condBeanIsStruct := false if len(condiBean) > 0 { if c, ok := condiBean[0].(map[string]interface{}); ok { @@ -268,7 +270,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } if k == reflect.Struct { var err error - autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false) + autoCond, err = session.statement.BuildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false) if err != nil { return 0, err } @@ -280,8 +282,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } if !condBeanIsStruct && table != nil { - if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled - autoCond1 := session.engine.CondDeleted(col) + if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled + autoCond1 := session.statement.CondDeleted(col) if autoCond == nil { autoCond = autoCond1 @@ -292,18 +294,17 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - st := &session.statement + st := session.statement var ( sqlStr string condArgs []interface{} condSQL string - cond = session.statement.cond.And(autoCond) + cond = session.statement.Conds().And(autoCond) - doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.checkVersion) + doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion) verValue *reflect.Value ) - if doIncVer { verValue, err = table.VersionColumn().ValueOf(bean) if err != nil { @@ -316,7 +317,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - condSQL, condArgs, err = builder.ToSQL(cond) + if len(colNames) <= 0 { + return 0, errors.New("No content found to be updated") + } + + condSQL, condArgs, err = 
session.statement.GenCondSQL(cond) if err != nil { return 0, err } @@ -334,24 +339,25 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var top string if st.LimitN != nil { limitValue := *st.LimitN - if st.Engine.dialect.DBType() == core.MYSQL { + switch session.engine.dialect.URI().DBType { + case schemas.MYSQL: condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue) - } else if st.Engine.dialect.DBType() == core.SQLITE { + case schemas.SQLITE: tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) - condSQL, condArgs, err = builder.ToSQL(cond) + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } - } else if st.Engine.dialect.DBType() == core.POSTGRES { + case schemas.POSTGRES: tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) - condSQL, condArgs, err = builder.ToSQL(cond) + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } @@ -359,14 +365,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } - } else if st.Engine.dialect.DBType() == core.MSSQL { - if st.OrderStr != "" && st.Engine.dialect.DBType() == core.MSSQL && - table != nil && len(table.PrimaryKeys) == 1 { + case schemas.MSSQL: + if st.OrderStr != "" && table != nil && len(table.PrimaryKeys) == 1 { cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], session.engine.Quote(tableName), condSQL), condArgs...) - condSQL, condArgs, err = builder.ToSQL(cond) + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } @@ -379,15 +384,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - if len(colNames) <= 0 { - return 0, errors.New("No content found to be updated") - } - var tableAlias = session.engine.Quote(tableName) var fromSQL string if session.statement.TableAlias != "" { - switch session.engine.dialect.DBType() { - case core.MSSQL: + switch session.engine.dialect.URI().DBType { + case schemas.MSSQL: fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias) tableAlias = session.statement.TableAlias default: @@ -411,9 +412,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache { - //session.cacheUpdate(table, tableName, sqlStr, args...) - session.engine.logger.Debug("[cacheUpdate] clear table ", tableName) + if cacher := session.engine.GetCacher(tableName); cacher != nil && session.statement.UseCache { + // session.cacheUpdate(table, tableName, sqlStr, args...) 
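The switch just above is how Update copes with "UPDATE ... LIMIT n" across engines: MySQL accepts LIMIT natively, SQLite restricts the update through a rowid subquery, Postgres through CTID, and MSSQL through TOP on a single primary key. The shape of the rewritten condition, sketched with plain strings instead of builder.Cond (names and quoting simplified):

package main

import "fmt"

// limitUpdateCond rewrites a WHERE clause so that at most n rows match,
// per engine. condSQL is assumed to already start with "WHERE".
func limitUpdateCond(dbType, table, pk, condSQL string, n int) string {
	switch dbType {
	case "mysql":
		return fmt.Sprintf("%s LIMIT %d", condSQL, n)
	case "sqlite3":
		return fmt.Sprintf("WHERE rowid IN (SELECT rowid FROM %s %s LIMIT %d)", table, condSQL, n)
	case "postgres":
		return fmt.Sprintf("WHERE CTID IN (SELECT CTID FROM %s %s LIMIT %d)", table, condSQL, n)
	case "mssql":
		return fmt.Sprintf("WHERE %s IN (SELECT TOP (%d) %s FROM %s %s)", pk, n, pk, table, condSQL)
	default:
		return condSQL
	}
}

func main() {
	fmt.Println(limitUpdateCond("postgres", "users", "id", "WHERE age > 18", 10))
	// WHERE CTID IN (SELECT CTID FROM users WHERE age > 18 LIMIT 10)
}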
+ session.engine.logger.Debugf("[cache] clear table: %v", tableName) cacher.ClearIds(tableName) cacher.ClearBeans(tableName) } @@ -424,7 +425,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 closure(bean) } if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok { - session.engine.logger.Debug("[event]", tableName, " has after update processor") + session.engine.logger.Debugf("[event] %v has after update processor", tableName) processor.AfterUpdate() } } else { @@ -458,11 +459,11 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac for _, col := range table.Columns() { if !col.IsVersion && !col.IsCreated && !col.IsUpdated { - if session.statement.omitColumnMap.contain(col.Name) { + if session.statement.OmitColumnMap.Contain(col.Name) { continue } } - if col.MapType == core.ONLYFROMDB { + if col.MapType == schemas.ONLYFROMDB { continue } @@ -472,46 +473,30 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac } fieldValue := *fieldValuePtr - if col.IsAutoIncrement { - switch fieldValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: - if fieldValue.Int() == 0 { - continue - } - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: - if fieldValue.Uint() == 0 { - continue - } - case reflect.String: - if len(fieldValue.String()) == 0 { - continue - } - case reflect.Ptr: - if fieldValue.Pointer() == 0 { - continue - } - } + if col.IsAutoIncrement && utils.IsValueZero(fieldValue) { + continue } - if (col.IsDeleted && !session.statement.unscoped) || col.IsCreated { + if (col.IsDeleted && !session.statement.GetUnscoped()) || col.IsCreated { continue } // if only update specify columns - if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) { continue } - if session.statement.incrColumns.isColExist(col.Name) { + + if session.statement.IncrColumns.IsColExist(col.Name) { continue - } else if session.statement.decrColumns.isColExist(col.Name) { + } else if session.statement.DecrColumns.IsColExist(col.Name) { continue - } else if session.statement.exprColumns.isColExist(col.Name) { + } else if session.statement.ExprColumns.IsColExist(col.Name) { continue } // !evalphobia! 
set fieldValue as nil when column is nullable and zero-value - if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { - if col.Nullable && isZeroValue(fieldValue) { + if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok { + if col.Nullable && utils.IsValueZero(fieldValue) { var nilValue *int fieldValue = reflect.ValueOf(nilValue) } @@ -527,10 +512,10 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac col := table.GetColumn(colName) setColumnTime(bean, col, t) }) - } else if col.IsVersion && session.statement.checkVersion { + } else if col.IsVersion && session.statement.CheckVersion { args = append(args, 1) } else { - arg, err := session.value2Interface(col, fieldValue) + arg, err := session.statement.Value2Interface(col, fieldValue) if err != nil { return colNames, args, err } diff --git a/sql_executor.go b/sql_executor.go index 525dad7..6877636 100644 --- a/sql_executor.go +++ b/sql_executor.go @@ -4,6 +4,8 @@ import ( "database/sql" "strings" "time" + + "github.com/xormplus/xorm/internal/utils" ) type SqlsExecutor struct { @@ -41,10 +43,10 @@ func (sqlsExecutor *SqlsExecutor) Execute() ([][]map[string]interface{}, map[str if sqlsExecutor.parmas == nil { switch sqlCmd { case "select": - model_1_results = sqlsExecutor.session.Sql(sqlStr).Query() + model_1_results = sqlsExecutor.session.SQL(sqlStr).Query() sqlModel = 1 case "insert", "delete", "update", "create", "drop": - model_2_results, err = sqlsExecutor.session.Sql(sqlStr).Execute() + model_2_results, err = sqlsExecutor.session.SQL(sqlStr).Execute() sqlModel = 2 default: sqlModel = 3 @@ -54,7 +56,7 @@ func (sqlsExecutor *SqlsExecutor) Execute() ([][]map[string]interface{}, map[str case []map[string]interface{}: parmaMap, _ := sqlsExecutor.parmas.([]map[string]interface{}) - key := NewV4().String() + time.Now().String() + key := utils.NewV4().String() + time.Now().String() sqlsExecutor.session.engine.AddSql(key, sqlStr) switch sqlCmd { case "select": @@ -71,7 +73,7 @@ func (sqlsExecutor *SqlsExecutor) Execute() ([][]map[string]interface{}, map[str case map[string]interface{}: parmaMap, _ := sqlsExecutor.parmas.(map[string]interface{}) - key := NewV4().String() + time.Now().String() + key := utils.NewV4().String() + time.Now().String() sqlsExecutor.session.engine.AddSql(key, sqlStr) switch sqlCmd { case "select": @@ -165,10 +167,10 @@ func (sqlsExecutor *SqlsExecutor) Execute() ([][]map[string]interface{}, map[str sqlCmd := strings.ToLower(strings.Split(sqlStr, " ")[0]) switch sqlCmd { case "select": - model_1_results = sqlsExecutor.session.Sql(sqlStr).Query() + model_1_results = sqlsExecutor.session.SQL(sqlStr).Query() sqlModel = 1 case "insert", "delete", "update", "create", "drop": - model_2_results, err = sqlsExecutor.session.Sql(sqlStr).Execute() + model_2_results, err = sqlsExecutor.session.SQL(sqlStr).Execute() sqlModel = 2 default: sqlModel = 3 @@ -247,16 +249,16 @@ func (sqlsExecutor *SqlsExecutor) Execute() ([][]map[string]interface{}, map[str if parmaSlice[i] == nil { switch sqlCmd { case "select": - model_1_results = sqlsExecutor.session.Sql(sqlStr).Query() + model_1_results = sqlsExecutor.session.SQL(sqlStr).Query() sqlModel = 1 case "insert", "delete", "update", "create", "drop": - model_2_results, err = sqlsExecutor.session.Sql(sqlStr).Execute() + model_2_results, err = sqlsExecutor.session.SQL(sqlStr).Execute() sqlModel = 2 default: sqlModel = 3 } } else { - key := NewV4().String() + time.Now().String() + key := utils.NewV4().String() + time.Now().String() 
sqlsExecutor.session.engine.AddSql(key, sqlStr) switch sqlCmd { case "select": @@ -347,11 +349,11 @@ func (sqlsExecutor *SqlsExecutor) Execute() ([][]map[string]interface{}, map[str switch sqlCmd { case "select": sqlModel = 1 - model_1_results = sqlsExecutor.session.Sql(sqlStr).Query() + model_1_results = sqlsExecutor.session.SQL(sqlStr).Query() case "insert", "delete", "update", "create", "drop": sqlModel = 2 - model_2_results, err = sqlsExecutor.session.Sql(sqlStr).Execute() + model_2_results, err = sqlsExecutor.session.SQL(sqlStr).Execute() default: sqlModel = 3 @@ -432,17 +434,17 @@ func (sqlsExecutor *SqlsExecutor) Execute() ([][]map[string]interface{}, map[str switch sqlCmd { case "select": sqlModel = 1 - model_1_results = sqlsExecutor.session.Sql(sqlStr).Query() + model_1_results = sqlsExecutor.session.SQL(sqlStr).Query() case "insert", "delete", "update", "create", "drop": sqlModel = 2 - model_2_results, err = sqlsExecutor.session.Sql(sqlStr).Execute() + model_2_results, err = sqlsExecutor.session.SQL(sqlStr).Execute() default: sqlModel = 3 } } else { - key := NewV4().String() + time.Now().String() + key := utils.NewV4().String() + time.Now().String() sqlsExecutor.session.engine.AddSql(key, sqlStr) parmaMap := parmasMap[k] switch sqlCmd { diff --git a/statement.go b/statement.go deleted file mode 100644 index 37caed4..0000000 --- a/statement.go +++ /dev/null @@ -1,1291 +0,0 @@ -// Copyright 2015 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "database/sql/driver" - "fmt" - "reflect" - "strings" - "time" - - "github.com/xormplus/builder" - "github.com/xormplus/core" -) - -// Statement save all the sql info for executing SQL -type Statement struct { - RefTable *core.Table - Engine *Engine - Start int - LimitN *int - idParam *core.PK - OrderStr string - JoinStr string - joinArgs []interface{} - GroupByStr string - HavingStr string - ColumnStr string - selectStr string - useAllCols bool - OmitStr string - AltTableName string - tableName string - RawSQL string - RawParams []interface{} - UseCascade bool - UseAutoJoin bool - StoreEngine string - Charset string - UseCache bool - UseAutoTime bool - noAutoCondition bool - IsDistinct bool - IsForUpdate bool - TableAlias string - allUseBool bool - checkVersion bool - unscoped bool - columnMap columnMap - omitColumnMap columnMap - mustColumnMap map[string]bool - nullableMap map[string]bool - incrColumns exprParams - decrColumns exprParams - exprColumns exprParams - cond builder.Cond - bufferSize int - context ContextCache - lastError error -} - -// Init reset all the statement's fields -func (statement *Statement) Init() { - statement.RefTable = nil - statement.Start = 0 - statement.LimitN = nil - statement.OrderStr = "" - statement.UseCascade = true - statement.JoinStr = "" - statement.joinArgs = make([]interface{}, 0) - statement.GroupByStr = "" - statement.HavingStr = "" - statement.ColumnStr = "" - statement.OmitStr = "" - statement.columnMap = columnMap{} - statement.omitColumnMap = columnMap{} - statement.AltTableName = "" - statement.tableName = "" - statement.idParam = nil - statement.RawSQL = "" - statement.RawParams = make([]interface{}, 0) - statement.UseCache = true - statement.UseAutoTime = true - statement.noAutoCondition = false - statement.IsDistinct = false - statement.IsForUpdate = false - statement.TableAlias = "" - statement.selectStr = "" - statement.allUseBool = false - statement.useAllCols = 
false - statement.mustColumnMap = make(map[string]bool) - statement.nullableMap = make(map[string]bool) - statement.checkVersion = true - statement.unscoped = false - statement.incrColumns = exprParams{} - statement.decrColumns = exprParams{} - statement.exprColumns = exprParams{} - statement.cond = builder.NewCond() - statement.bufferSize = 0 - statement.context = nil - statement.lastError = nil -} - -// NoAutoCondition if you do not want convert bean's field as query condition, then use this function -func (statement *Statement) NoAutoCondition(no ...bool) *Statement { - statement.noAutoCondition = true - if len(no) > 0 { - statement.noAutoCondition = no[0] - } - return statement -} - -// Alias set the table alias -func (statement *Statement) Alias(alias string) *Statement { - statement.TableAlias = alias - return statement -} - -// SQL adds raw sql statement -func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement { - switch query.(type) { - case (*builder.Builder): - var err error - statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL() - if err != nil { - statement.lastError = err - } - case string: - statement.RawSQL = query.(string) - statement.RawParams = args - default: - statement.lastError = ErrUnSupportedSQLType - } - - return statement -} - -// Where add Where statement -func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement { - return statement.And(query, args...) -} - -// And add Where & and statement -func (statement *Statement) And(query interface{}, args ...interface{}) *Statement { - switch query.(type) { - case string: - isExpr := false - var cargs []interface{} - for i, _ := range args { - if _, ok := args[i].(sqlExpr); ok { - isExpr = true - } - cargs = append(cargs, args[i]) - } - if isExpr { - sqlStr, _ := ConvertToBoundSQL(query.(string), cargs) - cond := builder.Expr(sqlStr) - statement.cond = statement.cond.And(cond) - } else { - cond := builder.Expr(query.(string), args...) - statement.cond = statement.cond.And(cond) - } - case map[string]interface{}: - queryMap := query.(map[string]interface{}) - newMap := make(map[string]interface{}) - for k, v := range queryMap { - newMap[statement.Engine.Quote(k)] = v - } - statement.cond = statement.cond.And(builder.Eq(newMap)) - case builder.Cond: - cond := query.(builder.Cond) - statement.cond = statement.cond.And(cond) - for _, v := range args { - if vv, ok := v.(builder.Cond); ok { - statement.cond = statement.cond.And(vv) - } - } - default: - statement.lastError = ErrConditionType - } - - return statement -} - -// Or add Where & Or statement -func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement { - switch query.(type) { - case string: - isExpr := false - var cargs []interface{} - for i, _ := range args { - if _, ok := args[i].(sqlExpr); ok { - isExpr = true - } - cargs = append(cargs, args[i]) - } - if isExpr { - sqlStr, _ := ConvertToBoundSQL(query.(string), cargs) - cond := builder.Expr(sqlStr) - statement.cond = statement.cond.Or(cond) - } else { - cond := builder.Expr(query.(string), args...) 
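The And method being deleted here (and Or just below it) only folds values into a builder.Cond tree; rendering that tree is what builder.ToSQL does throughout this patch. A minimal demonstration using the same builder package the file imports (output shown approximately):

package main

import (
	"fmt"

	"github.com/xormplus/builder"
)

func main() {
	// Compose conditions the way Statement.And / Statement.Or did,
	// then render the tree to SQL plus bound arguments.
	cond := builder.Eq{"status": 1}.
		And(builder.Expr("age > ?", 18)).
		Or(builder.Eq{"admin": true})
	sqlStr, args, err := builder.ToSQL(cond)
	fmt.Println(sqlStr, args, err)
	// roughly: (status=? AND age > ?) OR admin=? [1 18 true] <nil>
}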
- statement.cond = statement.cond.Or(cond) - } - case map[string]interface{}: - cond := builder.Eq(query.(map[string]interface{})) - statement.cond = statement.cond.Or(cond) - case builder.Cond: - cond := query.(builder.Cond) - statement.cond = statement.cond.Or(cond) - for _, v := range args { - if vv, ok := v.(builder.Cond); ok { - statement.cond = statement.cond.Or(vv) - } - } - default: - // TODO: not support condition type - } - return statement -} - -// In generate "Where column IN (?) " statement -func (statement *Statement) In(column string, args ...interface{}) *Statement { - in := builder.In(statement.Engine.Quote(column), args...) - statement.cond = statement.cond.And(in) - return statement -} - -// NotIn generate "Where column NOT IN (?) " statement -func (statement *Statement) NotIn(column string, args ...interface{}) *Statement { - notIn := builder.NotIn(statement.Engine.Quote(column), args...) - statement.cond = statement.cond.And(notIn) - return statement -} - -func (statement *Statement) setRefValue(v reflect.Value) error { - var err error - statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v)) - if err != nil { - return err - } - statement.tableName = statement.Engine.TableName(v, true) - return nil -} - -func (statement *Statement) setRefBean(bean interface{}) error { - var err error - statement.RefTable, err = statement.Engine.autoMapType(rValue(bean)) - if err != nil { - return err - } - statement.tableName = statement.Engine.TableName(bean, true) - return nil -} - -// Auto generating update columnes and values according a struct -func (statement *Statement) buildUpdates(bean interface{}, - includeVersion, includeUpdated, includeNil, - includeAutoIncr, update bool) ([]string, []interface{}) { - engine := statement.Engine - table := statement.RefTable - allUseBool := statement.allUseBool - useAllCols := statement.useAllCols - mustColumnMap := statement.mustColumnMap - nullableMap := statement.nullableMap - columnMap := statement.columnMap - omitColumnMap := statement.omitColumnMap - unscoped := statement.unscoped - - var colNames = make([]string, 0) - var args = make([]interface{}, 0) - for _, col := range table.Columns() { - if !includeVersion && col.IsVersion { - continue - } - if col.IsCreated && !columnMap.contain(col.Name) { - continue - } - if !includeUpdated && col.IsUpdated { - continue - } - if !includeAutoIncr && col.IsAutoIncrement { - continue - } - if col.IsDeleted && !unscoped { - continue - } - if omitColumnMap.contain(col.Name) { - continue - } - if len(columnMap) > 0 && !columnMap.contain(col.Name) { - continue - } - - if col.MapType == core.ONLYFROMDB { - continue - } - - if statement.incrColumns.isColExist(col.Name) { - continue - } else if statement.decrColumns.isColExist(col.Name) { - continue - } else if statement.exprColumns.isColExist(col.Name) { - continue - } - - fieldValuePtr, err := col.ValueOf(bean) - if err != nil { - engine.logger.Error(err) - continue - } - - fieldValue := *fieldValuePtr - fieldType := reflect.TypeOf(fieldValue.Interface()) - if fieldType == nil { - continue - } - - requiredField := useAllCols - includeNil := useAllCols - - if b, ok := getFlagForColumn(mustColumnMap, col); ok { - if b { - requiredField = true - } else { - continue - } - } - - // !evalphobia! 
set fieldValue as nil when column is nullable and zero-value - if b, ok := getFlagForColumn(nullableMap, col); ok { - if b && col.Nullable && isZero(fieldValue.Interface()) { - var nilValue *int - fieldValue = reflect.ValueOf(nilValue) - fieldType = reflect.TypeOf(fieldValue.Interface()) - includeNil = true - } - } - - var val interface{} - - if fieldValue.CanAddr() { - if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { - data, err := structConvert.ToDB() - if err != nil { - engine.logger.Error(err) - } else { - val = data - } - goto APPEND - } - } - - if structConvert, ok := fieldValue.Interface().(core.Conversion); ok { - data, err := structConvert.ToDB() - if err != nil { - engine.logger.Error(err) - } else { - val = data - } - goto APPEND - } - - if fieldType.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - if includeNil { - args = append(args, nil) - colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name))) - } - continue - } else if !fieldValue.IsValid() { - continue - } else { - // dereference ptr type to instance type - fieldValue = fieldValue.Elem() - fieldType = reflect.TypeOf(fieldValue.Interface()) - requiredField = true - } - } - - switch fieldType.Kind() { - case reflect.Bool: - if allUseBool || requiredField { - val = fieldValue.Interface() - } else { - // if a bool in a struct, it will not be as a condition because it default is false, - // please use Where() instead - continue - } - case reflect.String: - if !requiredField && fieldValue.String() == "" { - continue - } - // for MyString, should convert to string or panic - if fieldType.String() != reflect.String.String() { - val = fieldValue.String() - } else { - val = fieldValue.Interface() - } - case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if !requiredField && fieldValue.Int() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Float32, reflect.Float64: - if !requiredField && fieldValue.Float() == 0.0 { - continue - } - val = fieldValue.Interface() - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if !requiredField && fieldValue.Uint() == 0 { - continue - } - t := int64(fieldValue.Uint()) - val = reflect.ValueOf(&t).Interface() - case reflect.Struct: - if fieldType.ConvertibleTo(core.TimeType) { - t := fieldValue.Convert(core.TimeType).Interface().(time.Time) - if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { - continue - } - val = engine.formatColTime(col, t) - } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok { - val, _ = nulType.Value() - } else { - if !col.SQLType.IsJson() { - engine.autoMapType(fieldValue) - if table, ok := engine.Tables[fieldValue.Type()]; ok { - if len(table.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) - // fix non-int pk issues - if pkField.IsValid() && (!requiredField && !isZero(pkField.Interface())) { - val = pkField.Interface() - } else { - continue - } - } else { - // TODO: how to handler? 
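// Related tables with composite primary keys cannot be flattened into a
// single column value here, which is why the fallthrough below fails
// loudly instead of silently writing a wrong update.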
- panic("not supported") - } - } else { - val = fieldValue.Interface() - } - } else { - // Blank struct could not be as update data - if requiredField || !isStructZero(fieldValue) { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface())) - } - if col.SQLType.IsText() { - val = string(bytes) - } else if col.SQLType.IsBlob() { - val = bytes - } - } else { - continue - } - } - } - case reflect.Array, reflect.Slice, reflect.Map: - if !requiredField { - if fieldValue == reflect.Zero(fieldType) { - continue - } - if fieldType.Kind() == reflect.Array { - if isArrayValueZero(fieldValue) { - continue - } - } else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { - continue - } - } - - if col.SQLType.IsText() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if fieldType.Kind() == reflect.Slice && - fieldType.Elem().Kind() == reflect.Uint8 { - if fieldValue.Len() > 0 { - val = fieldValue.Bytes() - } else { - continue - } - } else if fieldType.Kind() == reflect.Array && - fieldType.Elem().Kind() == reflect.Uint8 { - val = fieldValue.Slice(0, 0).Interface() - } else { - bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = bytes - } - } else { - continue - } - default: - val = fieldValue.Interface() - } - - APPEND: - args = append(args, val) - if col.IsPrimaryKey && engine.dialect.DBType() == "ql" { - continue - } - colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name))) - } - - return colNames, args -} - -func (statement *Statement) needTableName() bool { - return len(statement.JoinStr) > 0 -} - -func (statement *Statement) colName(col *core.Column, tableName string) string { - if statement.needTableName() { - var nm = tableName - if len(statement.TableAlias) > 0 { - nm = statement.TableAlias - } - return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name) - } - return statement.Engine.Quote(col.Name) -} - -// TableName return current tableName -func (statement *Statement) TableName() string { - if statement.AltTableName != "" { - return statement.AltTableName - } - - return statement.tableName -} - -// ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?" -func (statement *Statement) ID(id interface{}) *Statement { - idValue := reflect.ValueOf(id) - idType := reflect.TypeOf(idValue.Interface()) - - switch idType { - case ptrPkType: - if pkPtr, ok := (id).(*core.PK); ok { - statement.idParam = pkPtr - return statement - } - case pkType: - if pk, ok := (id).(core.PK); ok { - statement.idParam = &pk - return statement - } - } - - switch idType.Kind() { - case reflect.String: - statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()} - return statement - } - - statement.idParam = &core.PK{id} - return statement -} - -// Incr Generate "Update ... Set column = column + arg" statement -func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { - if len(arg) > 0 { - statement.incrColumns.addParam(column, arg[0]) - } else { - statement.incrColumns.addParam(column, 1) - } - return statement -} - -// Decr Generate "Update ... 
Set column = column - arg" statement -func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { - if len(arg) > 0 { - statement.decrColumns.addParam(column, arg[0]) - } else { - statement.decrColumns.addParam(column, 1) - } - return statement -} - -// SetExpr Generate "Update ... Set column = {expression}" statement -func (statement *Statement) SetExpr(column string, expression interface{}) *Statement { - statement.exprColumns.addParam(column, expression) - return statement -} - -func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { - newColumns := make([]string, 0) - quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") - for _, col := range columns { - newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...))) - } - return newColumns -} - -func (statement *Statement) colmap2NewColsWithQuote() []string { - newColumns := make([]string, len(statement.columnMap), len(statement.columnMap)) - copy(newColumns, statement.columnMap) - for i := 0; i < len(statement.columnMap); i++ { - newColumns[i] = statement.Engine.Quote(newColumns[i]) - } - return newColumns -} - -// Distinct generates "DISTINCT col1, col2 " statement -func (statement *Statement) Distinct(columns ...string) *Statement { - statement.IsDistinct = true - statement.Cols(columns...) - return statement -} - -// ForUpdate generates "SELECT ... FOR UPDATE" statement -func (statement *Statement) ForUpdate() *Statement { - statement.IsForUpdate = true - return statement -} - -// Select replace select -func (statement *Statement) Select(str string) *Statement { - statement.selectStr = str - return statement -} - -// Cols generate "col1, col2" statement -func (statement *Statement) Cols(columns ...string) *Statement { - cols := col2NewCols(columns...) - for _, nc := range cols { - statement.columnMap.add(nc) - } - - newColumns := statement.colmap2NewColsWithQuote() - - statement.ColumnStr = strings.Join(newColumns, ", ") - statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1) - return statement -} - -// AllCols update use only: update all columns -func (statement *Statement) AllCols() *Statement { - statement.useAllCols = true - return statement -} - -// MustCols update use only: must update columns -func (statement *Statement) MustCols(columns ...string) *Statement { - newColumns := col2NewCols(columns...) - for _, nc := range newColumns { - statement.mustColumnMap[strings.ToLower(nc)] = true - } - return statement -} - -// UseBool indicates that use bool fields as update contents and query contiditions -func (statement *Statement) UseBool(columns ...string) *Statement { - if len(columns) > 0 { - statement.MustCols(columns...) - } else { - statement.allUseBool = true - } - return statement -} - -// Omit do not use the columns -func (statement *Statement) Omit(columns ...string) { - newColumns := col2NewCols(columns...) - for _, nc := range newColumns { - statement.omitColumnMap = append(statement.omitColumnMap, nc) - } - statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) -} - -// Nullable Update use only: update columns to null when value is nullable and zero-value -func (statement *Statement) Nullable(columns ...string) { - newColumns := col2NewCols(columns...) 
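Taken together, Incr/Decr/SetExpr and the column filters above (Cols, AllCols, MustCols, Omit, Nullable) decide which columns the generated UPDATE or INSERT touches. A minimal usage sketch, assuming the usual engine-level shortcuts exist with these signatures; the Article type is illustrative, not part of the library:

package demo

import "github.com/xormplus/xorm"

// Article is an illustrative bean.
type Article struct {
	Id    int64
	Name  string
	Views int
}

func touchArticle(engine *xorm.Engine) error {
	// UPDATE article SET views = views + 1, name = ? WHERE id = ?
	if _, err := engine.ID(1).Incr("views").Cols("name").
		Update(&Article{Name: "renamed"}); err != nil {
		return err
	}
	// INSERT INTO article (name) VALUES (?) -- views is omitted.
	_, err := engine.Omit("views").Insert(&Article{Name: "draft"})
	return err
}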
- for _, nc := range newColumns { - statement.nullableMap[strings.ToLower(nc)] = true - } -} - -// Top generate LIMIT limit statement -func (statement *Statement) Top(limit int) *Statement { - statement.Limit(limit) - return statement -} - -// Limit generate LIMIT start, limit statement -func (statement *Statement) Limit(limit int, start ...int) *Statement { - if limit > 0 { - statement.LimitN = &limit - } - if len(start) > 0 { - statement.Start = start[0] - } - return statement -} - -// OrderBy generate "Order By order" statement -func (statement *Statement) OrderBy(order string) *Statement { - if len(statement.OrderStr) > 0 { - statement.OrderStr += ", " - } - statement.OrderStr += order - return statement -} - -// Desc generate `ORDER BY xx DESC` -func (statement *Statement) Desc(colNames ...string) *Statement { - var buf strings.Builder - if len(statement.OrderStr) > 0 { - fmt.Fprint(&buf, statement.OrderStr, ", ") - } - newColNames := statement.col2NewColsWithQuote(colNames...) - fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, ")) - statement.OrderStr = buf.String() - return statement -} - -// Asc provide asc order by query condition, the input parameters are columns. -func (statement *Statement) Asc(colNames ...string) *Statement { - var buf strings.Builder - if len(statement.OrderStr) > 0 { - fmt.Fprint(&buf, statement.OrderStr, ", ") - } - newColNames := statement.col2NewColsWithQuote(colNames...) - fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, ")) - statement.OrderStr = buf.String() - return statement -} - -// Table tempororily set table name, the parameter could be a string or a pointer of struct -func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { - v := rValue(tableNameOrBean) - t := v.Type() - if t.Kind() == reflect.Struct { - var err error - statement.RefTable, err = statement.Engine.autoMapType(v) - if err != nil { - statement.Engine.logger.Error(err) - return statement - } - } - - statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true) - return statement -} - -// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN -func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { - var buf strings.Builder - if len(statement.JoinStr) > 0 { - fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) - } else { - fmt.Fprintf(&buf, "%v JOIN ", joinOP) - } - - switch tp := tablename.(type) { - case builder.Builder: - subSQL, subQueryArgs, err := tp.ToSQL() - if err != nil { - statement.lastError = err - return statement - } - tbs := strings.Split(tp.TableName(), ".") - quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") - - var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, "")) - fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) - statement.joinArgs = append(statement.joinArgs, subQueryArgs...) - case *builder.Builder: - subSQL, subQueryArgs, err := tp.ToSQL() - if err != nil { - statement.lastError = err - return statement - } - tbs := strings.Split(tp.TableName(), ".") - quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") - - var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, "")) - fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) - statement.joinArgs = append(statement.joinArgs, subQueryArgs...) 
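Both builder cases above splice a sub-select into the JOIN and derive the alias from the builder's FROM table. A hedged sketch of what that supports, assuming the engine-level Table/Join/Find shortcuts; User and the order table are illustrative:

package demo

import (
	"github.com/xormplus/builder"
	"github.com/xormplus/xorm"
)

// User is an illustrative bean.
type User struct {
	Id   int64
	Name string
}

func usersWithBigOrders(engine *xorm.Engine) ([]User, error) {
	// Joins (SELECT user_id FROM order WHERE amount>?) aliased as "order",
	// i.e. the "(%s) %s ON %v" shape built above.
	sub := builder.Select("user_id").From("order").Where(builder.Gt{"amount": 100})

	var users []User
	err := engine.Table("user").
		Join("INNER", sub, "user.id = order.user_id").
		Find(&users)
	return users, err
}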
- default: - tbName := statement.Engine.TableName(tablename, true) - fmt.Fprintf(&buf, "%s ON %v", tbName, condition) - } - - statement.JoinStr = buf.String() - statement.joinArgs = append(statement.joinArgs, args...) - return statement -} - -// GroupBy generate "Group By keys" statement -func (statement *Statement) GroupBy(keys string) *Statement { - statement.GroupByStr = keys - return statement -} - -// Having generate "Having conditions" statement -func (statement *Statement) Having(conditions string) *Statement { - statement.HavingStr = fmt.Sprintf("HAVING %v", conditions) - return statement -} - -// Unscoped always disable struct tag "deleted" -func (statement *Statement) Unscoped() *Statement { - statement.unscoped = true - return statement -} - -func (statement *Statement) genColumnStr() string { - if statement.RefTable == nil { - return "" - } - - var buf strings.Builder - columns := statement.RefTable.Columns() - - for _, col := range columns { - if statement.omitColumnMap.contain(col.Name) { - continue - } - - if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) { - continue - } - - if col.MapType == core.ONLYTODB { - continue - } - - if buf.Len() != 0 { - buf.WriteString(", ") - } - - if statement.JoinStr != "" { - if statement.TableAlias != "" { - buf.WriteString(statement.TableAlias) - } else { - buf.WriteString(statement.TableName()) - } - - buf.WriteString(".") - } - - statement.Engine.QuoteTo(&buf, col.Name) - } - - return buf.String() -} - -func (statement *Statement) genCreateTableSQL() string { - return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(), - statement.StoreEngine, statement.Charset) -} - -func (statement *Statement) genIndexSQL() []string { - var sqls []string - tbName := statement.TableName() - for _, index := range statement.RefTable.Indexes { - if index.Type == core.IndexType { - sql := statement.Engine.dialect.CreateIndexSql(tbName, index) - if sql != "" { - sqls = append(sqls, sql) - } - } - } - return sqls -} - -func uniqueName(tableName, uqeName string) string { - return fmt.Sprintf("UQE_%v_%v", tableName, uqeName) -} - -func (statement *Statement) genUniqueSQL() []string { - var sqls []string - tbName := statement.TableName() - for _, index := range statement.RefTable.Indexes { - if index.Type == core.UniqueType { - sql := statement.Engine.dialect.CreateIndexSql(tbName, index) - sqls = append(sqls, sql) - } - } - return sqls -} - -func (statement *Statement) genDelIndexSQL() []string { - var sqls []string - tbName := statement.TableName() - idxPrefixName := strings.Replace(tbName, `"`, "", -1) - idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1) - for idxName, index := range statement.RefTable.Indexes { - var rIdxName string - if index.Type == core.UniqueType { - rIdxName = uniqueName(idxPrefixName, idxName) - } else if index.Type == core.IndexType { - rIdxName = indexName(idxPrefixName, idxName) - } - sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true))) - if statement.Engine.dialect.IndexOnTable() { - sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName)) - } - sqls = append(sqls, sql) - } - return sqls -} - -func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) { - quote := statement.Engine.Quote - sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()), - col.String(statement.Engine.dialect)) - if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 { - sql += 
" COMMENT '" + col.Comment + "'" - } - sql += ";" - return sql, []interface{}{} -} - -func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { - return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, - statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) -} - -func (statement *Statement) mergeConds(bean interface{}) error { - if !statement.noAutoCondition { - var addedTableName = (len(statement.JoinStr) > 0) - autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName) - if err != nil { - return err - } - statement.cond = statement.cond.And(autoCond) - } - - if err := statement.processIDParam(); err != nil { - return err - } - return nil -} - -func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) { - if err := statement.mergeConds(bean); err != nil { - return "", nil, err - } - - return builder.ToSQL(statement.cond) -} - -func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) { - v := rValue(bean) - isStruct := v.Kind() == reflect.Struct - if isStruct { - statement.setRefBean(bean) - } - - var columnStr = statement.ColumnStr - if len(statement.selectStr) > 0 { - columnStr = statement.selectStr - } else { - // TODO: always generate column names, not use * even if join - if len(statement.JoinStr) == 0 { - if len(columnStr) == 0 { - if len(statement.GroupByStr) > 0 { - columnStr = statement.Engine.quoteColumns(statement.GroupByStr) - } else { - columnStr = statement.genColumnStr() - } - } - } else { - if len(columnStr) == 0 { - if len(statement.GroupByStr) > 0 { - columnStr = statement.Engine.quoteColumns(statement.GroupByStr) - } - } - } - } - - if len(columnStr) == 0 { - columnStr = "*" - } - - if isStruct { - if err := statement.mergeConds(bean); err != nil { - return "", nil, err - } - } else { - if err := statement.processIDParam(); err != nil { - return "", nil, err - } - } - condSQL, condArgs, err := builder.ToSQL(statement.cond) - if err != nil { - return "", nil, err - } - - sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true) - if err != nil { - return "", nil, err - } - - return sqlStr, append(statement.joinArgs, condArgs...), nil -} - -func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) { - var condSQL string - var condArgs []interface{} - var err error - if len(beans) > 0 { - statement.setRefBean(beans[0]) - condSQL, condArgs, err = statement.genConds(beans[0]) - } else { - condSQL, condArgs, err = builder.ToSQL(statement.cond) - } - if err != nil { - return "", nil, err - } - - var selectSQL = statement.selectStr - if len(selectSQL) <= 0 { - if statement.IsDistinct { - selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr) - } else { - selectSQL = "count(*)" - } - } - sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false) - if err != nil { - return "", nil, err - } - - return sqlStr, append(statement.joinArgs, condArgs...), nil -} - -func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { - statement.setRefBean(bean) - - var sumStrs = make([]string, 0, len(columns)) - for _, colName := range columns { - if !strings.Contains(colName, " ") && 
!strings.Contains(colName, "(") { - colName = statement.Engine.Quote(colName) - } - sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName)) - } - sumSelect := strings.Join(sumStrs, ", ") - - condSQL, condArgs, err := statement.genConds(bean) - if err != nil { - return "", nil, err - } - - sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true) - if err != nil { - return "", nil, err - } - - return sqlStr, append(statement.joinArgs, condArgs...), nil -} - -func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) { - var ( - distinct string - dialect = statement.Engine.Dialect() - quote = statement.Engine.Quote - fromStr = " FROM " - top, mssqlCondi, whereStr string - ) - if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { - distinct = "DISTINCT " - } - if len(condSQL) > 0 { - whereStr = " WHERE " + condSQL - } - - if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") { - fromStr += statement.TableName() - } else { - fromStr += quote(statement.TableName()) - } - - if statement.TableAlias != "" { - if dialect.DBType() == core.ORACLE { - fromStr += " " + quote(statement.TableAlias) - } else { - fromStr += " AS " + quote(statement.TableAlias) - } - } - if statement.JoinStr != "" { - fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr) - } - - pLimitN := statement.LimitN - if dialect.DBType() == core.MSSQL { - if pLimitN != nil { - LimitNValue := *pLimitN - top = fmt.Sprintf("TOP %d ", LimitNValue) - } - if statement.Start > 0 { - var column string - if len(statement.RefTable.PKColumns()) == 0 { - for _, index := range statement.RefTable.Indexes { - if len(index.Cols) == 1 { - column = index.Cols[0] - break - } - } - if len(column) == 0 { - column = statement.RefTable.ColumnsSeq()[0] - } - } else { - column = statement.RefTable.PKColumns()[0].Name - } - if statement.needTableName() { - if len(statement.TableAlias) > 0 { - column = statement.TableAlias + "." + column - } else { - column = statement.TableName() + "." 
+ column - } - } - - var orderStr string - if needOrderBy && len(statement.OrderStr) > 0 { - orderStr = " ORDER BY " + statement.OrderStr - } - - var groupStr string - if len(statement.GroupByStr) > 0 { - groupStr = " GROUP BY " + statement.GroupByStr - } - mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))", - column, statement.Start, column, fromStr, whereStr, orderStr, groupStr) - } - } - - var buf strings.Builder - fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) - if len(mssqlCondi) > 0 { - if len(whereStr) > 0 { - fmt.Fprint(&buf, " AND ", mssqlCondi) - } else { - fmt.Fprint(&buf, " WHERE ", mssqlCondi) - } - } - - if statement.GroupByStr != "" { - fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr) - } - if statement.HavingStr != "" { - fmt.Fprint(&buf, " ", statement.HavingStr) - } - if needOrderBy && statement.OrderStr != "" { - fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr) - } - if needLimit { - if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { - if statement.Start > 0 { - if pLimitN != nil { - fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start) - } else { - fmt.Fprintf(&buf, "LIMIT 0 OFFSET %v", statement.Start) - } - } else if pLimitN != nil { - fmt.Fprint(&buf, " LIMIT ", *pLimitN) - } - } else if dialect.DBType() == core.ORACLE { - if statement.Start != 0 || pLimitN != nil { - oldString := buf.String() - buf.Reset() - rawColStr := columnStr - if rawColStr == "*" { - rawColStr = "at.*" - } - fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", - columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start) - } - } - } - if statement.IsForUpdate { - return dialect.ForUpdateSql(buf.String()), nil - } - - return buf.String(), nil -} - -func (statement *Statement) processIDParam() error { - if statement.idParam == nil || statement.RefTable == nil { - return nil - } - - if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) { - return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d", - len(statement.RefTable.PrimaryKeys), - len(*statement.idParam), - ) - } - - for i, col := range statement.RefTable.PKColumns() { - var colName = statement.colName(col, statement.TableName()) - statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]}) - } - return nil -} - -func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string { - var colnames = make([]string, len(cols)) - for i, col := range cols { - if includeTableName { - colnames[i] = statement.Engine.Quote(statement.TableName()) + - "." 
+ statement.Engine.Quote(col.Name) - } else { - colnames[i] = statement.Engine.Quote(col.Name) - } - } - return strings.Join(colnames, ", ") -} - -func (statement *Statement) convertIDSQL(sqlStr string) string { - if statement.RefTable != nil { - cols := statement.RefTable.PKColumns() - if len(cols) == 0 { - return "" - } - - colstrs := statement.joinColumns(cols, false) - sqls := splitNNoCase(sqlStr, " from ", 2) - if len(sqls) != 2 { - return "" - } - - var top string - pLimitN := statement.LimitN - if pLimitN != nil && statement.Engine.dialect.DBType() == core.MSSQL { - top = fmt.Sprintf("TOP %d ", *pLimitN) - } - - newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1]) - return newsql - } - return "" -} - -func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) { - if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { - return "", "" - } - - colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true) - sqls := splitNNoCase(sqlStr, "where", 2) - if len(sqls) != 2 { - if len(sqls) == 1 { - return sqls[0], fmt.Sprintf("SELECT %v FROM %v", - colstrs, statement.Engine.Quote(statement.TableName())) - } - return "", "" - } - - var whereStr = sqls[1] - - // TODO: for postgres only, if any other database? - var paraStr string - if statement.Engine.dialect.DBType() == core.POSTGRES { - paraStr = "$" - } else if statement.Engine.dialect.DBType() == core.MSSQL { - paraStr = ":" - } - - if paraStr != "" { - if strings.Contains(sqls[1], paraStr) { - dollers := strings.Split(sqls[1], paraStr) - whereStr = dollers[0] - for i, c := range dollers[1:] { - ccs := strings.SplitN(c, " ", 2) - whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1]) - } - } - } - - return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v", - colstrs, statement.Engine.Quote(statement.TableName()), - whereStr) -} diff --git a/statement_columnmap.go b/statement_columnmap.go deleted file mode 100644 index 18c3e86..0000000 --- a/statement_columnmap.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2020 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import "strings" - -type columnMap []string - -func (m columnMap) contain(colName string) bool { - if len(m) == 0 { - return false - } - - n := len(colName) - for _, mk := range m { - if len(mk) != n { - continue - } - if strings.EqualFold(mk, colName) { - return true - } - } - - return false -} - -func (m *columnMap) add(colName string) bool { - if m.contain(colName) { - return false - } - *m = append(*m, colName) - return true -} diff --git a/statement_quote.go b/statement_quote.go deleted file mode 100644 index fc50169..0000000 --- a/statement_quote.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2020 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -func trimQuote(s string) string { - if len(s) == 0 { - return s - } - - if s[0] == '`' { - s = s[1:] - } - if len(s) > 0 && s[len(s)-1] == '`' { - return s[:len(s)-1] - } - return s -} diff --git a/table_name.go b/table_name.go deleted file mode 100644 index 3a9617b..0000000 --- a/table_name.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2020 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
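columnMap and trimQuote above are small enough to pin down with a table of cases. A hedged test sketch; it assumes it compiles inside package xorm, since both helpers are unexported:

package xorm

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestColumnMapAndTrimQuote(t *testing.T) {
	var m columnMap
	assert.True(t, m.add("UserName"))
	assert.False(t, m.add("username")) // contain matches case-insensitively
	assert.True(t, m.contain("USERNAME"))

	assert.Equal(t, "name", trimQuote("`name`"))
	assert.Equal(t, "name", trimQuote("`name")) // a lone leading quote is stripped too
	assert.Equal(t, "", trimQuote(""))
}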
- -package xorm - -import ( - "reflect" - - "github.com/xormplus/core" -) - -func getTableName(mapper core.IMapper, v reflect.Value) string { - if t, ok := v.Interface().(TableName); ok { - return t.TableName() - } - if v.Type().Implements(tpTableName) { - return v.Interface().(TableName).TableName() - } - if v.Kind() == reflect.Ptr { - v = v.Elem() - if t, ok := v.Interface().(TableName); ok { - return t.TableName() - } - if v.Type().Implements(tpTableName) { - return v.Interface().(TableName).TableName() - } - } - - return mapper.Obj2Table(v.Type().Name()) -} diff --git a/table_name_test.go b/table_name_test.go deleted file mode 100644 index dee7828..0000000 --- a/table_name_test.go +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2020 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "reflect" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/xormplus/core" -) - -type TestTableNameStruct struct{} - -func (t *TestTableNameStruct) TableName() string { - return "my_test_table_name_struct" -} - -func TestGetTableName(t *testing.T) { - var kases = []struct { - mapper core.IMapper - v reflect.Value - expectedTableName string - }{ - { - core.SnakeMapper{}, - reflect.ValueOf(new(Userinfo)), - "userinfo", - }, - { - core.SnakeMapper{}, - reflect.ValueOf(Userinfo{}), - "userinfo", - }, - { - core.SameMapper{}, - reflect.ValueOf(new(Userinfo)), - "Userinfo", - }, - { - core.SameMapper{}, - reflect.ValueOf(Userinfo{}), - "Userinfo", - }, - { - core.SnakeMapper{}, - reflect.ValueOf(new(MyGetCustomTableImpletation)), - getCustomTableName, - }, - { - core.SnakeMapper{}, - reflect.ValueOf(MyGetCustomTableImpletation{}), - getCustomTableName, - }, - { - core.SnakeMapper{}, - reflect.ValueOf(MyGetCustomTableImpletation{}), - getCustomTableName, - }, - { - core.SnakeMapper{}, - reflect.ValueOf(new(TestTableNameStruct)), - new(TestTableNameStruct).TableName(), - }, - } - - for _, kase := range kases { - assert.EqualValues(t, kase.expectedTableName, getTableName(kase.mapper, kase.v)) - } -} diff --git a/tag_cache_test.go b/tag_cache_test.go deleted file mode 100644 index 30e2c51..0000000 --- a/tag_cache_test.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCacheTag(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type CacheDomain struct { - Id int64 `xorm:"pk cache"` - Name string - } - - assert.NoError(t, testEngine.CreateTables(&CacheDomain{})) - assert.True(t, testEngine.GetCacher(testEngine.TableName(&CacheDomain{})) != nil) -} - -func TestNoCacheTag(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type NoCacheDomain struct { - Id int64 `xorm:"pk nocache"` - Name string - } - - assert.NoError(t, testEngine.CreateTables(&NoCacheDomain{})) - assert.True(t, testEngine.GetCacher(testEngine.TableName(&NoCacheDomain{})) == nil) -} diff --git a/tag_extends_test.go b/tag_extends_test.go deleted file mode 100644 index de52914..0000000 --- a/tag_extends_test.go +++ /dev/null @@ -1,608 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
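getTableName above tries an explicit TableName() method first, on the value and again after dereferencing a pointer, and only then falls back to mapper.Obj2Table, which is what the cases in TestGetTableName exercise. A short illustration of that priority; ShopOrder is hypothetical:

package demo

// ShopOrder would map to "shop_order" under core.SnakeMapper, but the
// explicit method below wins inside getTableName.
type ShopOrder struct{}

// TableName satisfies the xorm TableName interface.
func (ShopOrder) TableName() string { return "orders_v2" }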
- -package xorm - -import ( - "errors" - "fmt" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/xormplus/core" -) - -type tempUser struct { - Id int64 - Username string -} - -type tempUser2 struct { - TempUser tempUser `xorm:"extends"` - Departname string -} - -type tempUser3 struct { - Temp *tempUser `xorm:"extends"` - Departname string -} - -type tempUser4 struct { - TempUser2 tempUser2 `xorm:"extends"` -} - -type Userinfo struct { - Uid int64 `xorm:"id pk not null autoincr"` - Username string `xorm:"unique"` - Departname string - Alias string `xorm:"-"` - Created time.Time - Detail Userdetail `xorm:"detail_id int(11)"` - Height float64 - Avatar []byte - IsMan bool -} - -type Userdetail struct { - Id int64 - Intro string `xorm:"text"` - Profile string `xorm:"varchar(2000)"` -} - -type UserAndDetail struct { - Userinfo `xorm:"extends"` - Userdetail `xorm:"extends"` -} - -func TestExtends(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&tempUser2{}) - assert.NoError(t, err) - - err = testEngine.CreateTables(&tempUser2{}) - assert.NoError(t, err) - - tu := &tempUser2{tempUser{0, "extends"}, "dev depart"} - _, err = testEngine.Insert(tu) - assert.NoError(t, err) - - tu2 := &tempUser2{} - _, err = testEngine.Get(tu2) - assert.NoError(t, err) - - tu3 := &tempUser2{tempUser{0, "extends update"}, ""} - _, err = testEngine.ID(tu2.TempUser.Id).Update(tu3) - assert.NoError(t, err) - - err = testEngine.DropTables(&tempUser4{}) - assert.NoError(t, err) - - err = testEngine.CreateTables(&tempUser4{}) - assert.NoError(t, err) - - tu8 := &tempUser4{tempUser2{tempUser{0, "extends"}, "dev depart"}} - _, err = testEngine.Insert(tu8) - assert.NoError(t, err) - - tu9 := &tempUser4{} - _, err = testEngine.Get(tu9) - assert.NoError(t, err) - - if tu9.TempUser2.TempUser.Username != tu8.TempUser2.TempUser.Username || tu9.TempUser2.Departname != tu8.TempUser2.Departname { - err = errors.New(fmt.Sprintln("not equal for", tu8, tu9)) - t.Error(err) - panic(err) - } - - tu10 := &tempUser4{tempUser2{tempUser{0, "extends update"}, ""}} - _, err = testEngine.ID(tu9.TempUser2.TempUser.Id).Update(tu10) - assert.NoError(t, err) - - err = testEngine.DropTables(&tempUser3{}) - assert.NoError(t, err) - - err = testEngine.CreateTables(&tempUser3{}) - assert.NoError(t, err) - - tu4 := &tempUser3{&tempUser{0, "extends"}, "dev depart"} - _, err = testEngine.Insert(tu4) - assert.NoError(t, err) - - tu5 := &tempUser3{} - _, err = testEngine.Get(tu5) - assert.NoError(t, err) - - if tu5.Temp == nil { - err = errors.New("error get data extends") - t.Error(err) - panic(err) - } - if tu5.Temp.Id != 1 || tu5.Temp.Username != "extends" || - tu5.Departname != "dev depart" { - err = errors.New("error get data extends") - t.Error(err) - panic(err) - } - - tu6 := &tempUser3{&tempUser{0, "extends update"}, ""} - _, err = testEngine.ID(tu5.Temp.Id).Update(tu6) - assert.NoError(t, err) - - users := make([]tempUser3, 0) - err = testEngine.Find(&users) - assert.NoError(t, err) - assert.EqualValues(t, 1, len(users), "error get data not 1") - - assertSync(t, new(Userinfo), new(Userdetail)) - - detail := Userdetail{ - Intro: "I'm in China", - } - _, err = testEngine.Insert(&detail) - assert.NoError(t, err) - - _, err = testEngine.Insert(&Userinfo{ - Username: "lunny", - Detail: detail, - }) - assert.NoError(t, err) - - var info UserAndDetail - qt := testEngine.Quote - ui := testEngine.TableName(new(Userinfo), true) - ud := testEngine.TableName(&detail, true) - uiid := 
testEngine.GetColumnMapper().Obj2Table("Id") - udid := "detail_id" - sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s", - qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid)) - b, err := testEngine.SQL(sql).NoCascade().Get(&info) - assert.NoError(t, err) - if !b { - err = errors.New("should has lest one record") - t.Error(err) - panic(err) - } - fmt.Println(info) - if info.Userinfo.Uid == 0 || info.Userdetail.Id == 0 { - err = errors.New("all of the id should has value") - t.Error(err) - panic(err) - } - - fmt.Println("----join--info2") - var info2 UserAndDetail - b, err = testEngine.Table(&Userinfo{}). - Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). - NoCascade().Get(&info2) - if err != nil { - t.Error(err) - panic(err) - } - if !b { - err = errors.New("should has lest one record") - t.Error(err) - panic(err) - } - if info2.Userinfo.Uid == 0 || info2.Userdetail.Id == 0 { - err = errors.New("all of the id should has value") - t.Error(err) - panic(err) - } - fmt.Println(info2) - - fmt.Println("----join--infos2") - var infos2 = make([]UserAndDetail, 0) - err = testEngine.Table(&Userinfo{}). - Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). - NoCascade(). - Find(&infos2) - assert.NoError(t, err) - fmt.Println(infos2) -} - -type MessageBase struct { - Id int64 `xorm:"int(11) pk autoincr"` - TypeId int64 `xorm:"int(11) notnull"` -} - -type Message struct { - MessageBase `xorm:"extends"` - Title string `xorm:"varchar(100) notnull"` - Content string `xorm:"text notnull"` - Uid int64 `xorm:"int(11) notnull"` - ToUid int64 `xorm:"int(11) notnull"` - CreateTime time.Time `xorm:"datetime notnull created"` -} - -type MessageUser struct { - Id int64 - Name string -} - -type MessageType struct { - Id int64 - Name string -} - -type MessageExtend3 struct { - Message `xorm:"extends"` - Sender MessageUser `xorm:"extends"` - Receiver MessageUser `xorm:"extends"` - Type MessageType `xorm:"extends"` -} - -type MessageExtend4 struct { - Message `xorm:"extends"` - MessageUser `xorm:"extends"` - MessageType `xorm:"extends"` -} - -func TestExtends2(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{}) - assert.NoError(t, err) - - err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) - assert.NoError(t, err) - - var sender = MessageUser{Name: "sender"} - var receiver = MessageUser{Name: "receiver"} - var msgtype = MessageType{Name: "type"} - _, err = testEngine.Insert(&sender, &receiver, &msgtype) - assert.NoError(t, err) - - msg := Message{ - MessageBase: MessageBase{ - Id: msgtype.Id, - }, - Title: "test", - Content: "test", - Uid: sender.Id, - ToUid: receiver.Id, - } - - session := testEngine.NewSession() - defer session.Close() - - // MSSQL deny insert identity column excep declare as below - if testEngine.Dialect().DBType() == core.MSSQL { - err = session.Begin() - assert.NoError(t, err) - _, err = session.Exec("SET IDENTITY_INSERT message ON") - assert.NoError(t, err) - } - cnt, err := session.Insert(&msg) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - if testEngine.Dialect().DBType() == core.MSSQL { - err = session.Commit() - assert.NoError(t, err) - } - - var mapper = testEngine.GetTableMapper().Obj2Table - var quote = testEngine.Quote - userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) - typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) - msgTableName := 
quote(testEngine.TableName(mapper("Message"), true)) - - list := make([]Message, 0) - err = session.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). - Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`"). - Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). - Find(&list) - assert.NoError(t, err) - - assert.EqualValues(t, 1, len(list), fmt.Sprintln("should have 1 message, got", len(list))) - assert.EqualValues(t, msg.Id, list[0].Id, fmt.Sprintln("should message equal", list[0], msg)) -} - -func TestExtends3(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) - if err != nil { - t.Error(err) - panic(err) - } - - var sender = MessageUser{Name: "sender"} - var receiver = MessageUser{Name: "receiver"} - var msgtype = MessageType{Name: "type"} - _, err = testEngine.Insert(&sender, &receiver, &msgtype) - if err != nil { - t.Error(err) - panic(err) - } - - msg := Message{ - MessageBase: MessageBase{ - Id: msgtype.Id, - }, - Title: "test", - Content: "test", - Uid: sender.Id, - ToUid: receiver.Id, - } - - session := testEngine.NewSession() - defer session.Close() - - // MSSQL deny insert identity column excep declare as below - if testEngine.Dialect().DBType() == core.MSSQL { - err = session.Begin() - assert.NoError(t, err) - _, err = session.Exec("SET IDENTITY_INSERT message ON") - assert.NoError(t, err) - } - _, err = session.Insert(&msg) - assert.NoError(t, err) - - if testEngine.Dialect().DBType() == core.MSSQL { - err = session.Commit() - assert.NoError(t, err) - } - - var mapper = testEngine.GetTableMapper().Obj2Table - var quote = testEngine.Quote - userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) - typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) - msgTableName := quote(testEngine.TableName(mapper("Message"), true)) - - list := make([]MessageExtend3, 0) - err = session.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). - Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`"). - Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). 
- Find(&list) - assert.NoError(t, err) - - if len(list) != 1 { - err = errors.New(fmt.Sprintln("should have 1 message, got", len(list))) - t.Error(err) - panic(err) - } - - if list[0].Message.Id != msg.Id { - err = errors.New(fmt.Sprintln("should message equal", list[0].Message, msg)) - t.Error(err) - panic(err) - } - - if list[0].Sender.Id != sender.Id || list[0].Sender.Name != sender.Name { - err = errors.New(fmt.Sprintln("should sender equal", list[0].Sender, sender)) - t.Error(err) - panic(err) - } - - if list[0].Receiver.Id != receiver.Id || list[0].Receiver.Name != receiver.Name { - err = errors.New(fmt.Sprintln("should receiver equal", list[0].Receiver, receiver)) - t.Error(err) - panic(err) - } - - if list[0].Type.Id != msgtype.Id || list[0].Type.Name != msgtype.Name { - err = errors.New(fmt.Sprintln("should msgtype equal", list[0].Type, msgtype)) - t.Error(err) - panic(err) - } -} - -func TestExtends4(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) - if err != nil { - t.Error(err) - panic(err) - } - - var sender = MessageUser{Name: "sender"} - var msgtype = MessageType{Name: "type"} - _, err = testEngine.Insert(&sender, &msgtype) - if err != nil { - t.Error(err) - panic(err) - } - - msg := Message{ - MessageBase: MessageBase{ - Id: msgtype.Id, - }, - Title: "test", - Content: "test", - Uid: sender.Id, - } - - session := testEngine.NewSession() - defer session.Close() - - // MSSQL deny insert identity column excep declare as below - if testEngine.Dialect().DBType() == core.MSSQL { - err = session.Begin() - assert.NoError(t, err) - _, err = session.Exec("SET IDENTITY_INSERT message ON") - assert.NoError(t, err) - } - _, err = session.Insert(&msg) - assert.NoError(t, err) - - if testEngine.Dialect().DBType() == core.MSSQL { - err = session.Commit() - assert.NoError(t, err) - } - - var mapper = testEngine.GetTableMapper().Obj2Table - var quote = testEngine.Quote - userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) - typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) - msgTableName := quote(testEngine.TableName(mapper("Message"), true)) - - list := make([]MessageExtend4, 0) - err = session.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). - Join("LEFT", typeTableName, typeTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). 
- Find(&list) - if err != nil { - t.Error(err) - panic(err) - } - - if len(list) != 1 { - err = errors.New(fmt.Sprintln("should have 1 message, got", len(list))) - t.Error(err) - panic(err) - } - - if list[0].Message.Id != msg.Id { - err = errors.New(fmt.Sprintln("should message equal", list[0].Message, msg)) - t.Error(err) - panic(err) - } - - if list[0].MessageUser.Id != sender.Id || list[0].MessageUser.Name != sender.Name { - err = errors.New(fmt.Sprintln("should sender equal", list[0].MessageUser, sender)) - t.Error(err) - panic(err) - } - - if list[0].MessageType.Id != msgtype.Id || list[0].MessageType.Name != msgtype.Name { - err = errors.New(fmt.Sprintln("should msgtype equal", list[0].MessageType, msgtype)) - t.Error(err) - panic(err) - } -} - -type Size struct { - ID int64 `xorm:"int(4) 'id' pk autoincr"` - Width float32 `json:"width" xorm:"float 'Width'"` - Height float32 `json:"height" xorm:"float 'Height'"` -} - -type Book struct { - ID int64 `xorm:"int(4) 'id' pk autoincr"` - SizeOpen *Size `xorm:"extends('Open')"` - SizeClosed *Size `xorm:"extends('Closed')"` - Size *Size `xorm:"extends('')"` -} - -func TestExtends5(t *testing.T) { - assert.NoError(t, prepareEngine()) - err := testEngine.DropTables(&Book{}, &Size{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&Size{}, &Book{}) - if err != nil { - t.Error(err) - panic(err) - } - - var sc = Size{Width: 0.2, Height: 0.4} - var so = Size{Width: 0.2, Height: 0.8} - var s = Size{Width: 0.15, Height: 1.5} - var bk1 = Book{ - SizeOpen: &so, - SizeClosed: &sc, - Size: &s, - } - var bk2 = Book{ - SizeOpen: &so, - } - var bk3 = Book{ - SizeClosed: &sc, - Size: &s, - } - var bk4 = Book{} - var bk5 = Book{Size: &s} - _, err = testEngine.Insert(&sc, &so, &s, &bk1, &bk2, &bk3, &bk4, &bk5) - if err != nil { - t.Fatal(err) - } - - var books = map[int64]Book{ - bk1.ID: bk1, - bk2.ID: bk2, - bk3.ID: bk3, - bk4.ID: bk4, - bk5.ID: bk5, - } - - session := testEngine.NewSession() - defer session.Close() - - var mapper = testEngine.GetTableMapper().Obj2Table - var quote = testEngine.Quote - bookTableName := quote(testEngine.TableName(mapper("Book"), true)) - sizeTableName := quote(testEngine.TableName(mapper("Size"), true)) - - list := make([]Book, 0) - err = session. - Select(fmt.Sprintf( - "%s.%s, sc.%s AS %s, sc.%s AS %s, s.%s, s.%s", - quote(bookTableName), - quote("id"), - quote("Width"), - quote("ClosedWidth"), - quote("Height"), - quote("ClosedHeight"), - quote("Width"), - quote("Height"), - )). - Table(bookTableName). - Join( - "LEFT", - sizeTableName+" AS `sc`", - bookTableName+".`SizeClosed`=sc.`id`", - ). - Join( - "LEFT", - sizeTableName+" AS `s`", - bookTableName+".`Size`=s.`id`", - ). 
- Find(&list) - if err != nil { - t.Error(err) - panic(err) - } - - for _, book := range list { - if ok := assert.Equal(t, books[book.ID].SizeClosed.Width, book.SizeClosed.Width); !ok { - t.Error("Not bounded size closed") - panic("Not bounded size closed") - } - - if ok := assert.Equal(t, books[book.ID].SizeClosed.Height, book.SizeClosed.Height); !ok { - t.Error("Not bounded size closed") - panic("Not bounded size closed") - } - - if books[book.ID].Size != nil || book.Size != nil { - if ok := assert.Equal(t, books[book.ID].Size.Width, book.Size.Width); !ok { - t.Error("Not bounded size") - panic("Not bounded size") - } - - if ok := assert.Equal(t, books[book.ID].Size.Height, book.Size.Height); !ok { - t.Error("Not bounded size") - panic("Not bounded size") - } - } - } -} diff --git a/tag_id_test.go b/tag_id_test.go deleted file mode 100644 index d597bfa..0000000 --- a/tag_id_test.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/xormplus/core" -) - -type IDGonicMapper struct { - ID int64 -} - -func TestGonicMapperID(t *testing.T) { - assert.NoError(t, prepareEngine()) - - oldMapper := testEngine.GetColumnMapper() - testEngine.UnMapType(rValue(new(IDGonicMapper)).Type()) - testEngine.SetMapper(core.LintGonicMapper) - defer func() { - testEngine.UnMapType(rValue(new(IDGonicMapper)).Type()) - testEngine.SetMapper(oldMapper) - }() - - err := testEngine.CreateTables(new(IDGonicMapper)) - if err != nil { - t.Fatal(err) - } - - tables, err := testEngine.DBMetas() - if err != nil { - t.Fatal(err) - } - - for _, tb := range tables { - if tb.Name == "id_gonic_mapper" { - if len(tb.PKColumns()) != 1 || tb.PKColumns()[0].Name != "id" { - t.Fatal(tb) - } - return - } - } - - t.Fatal("not table id_gonic_mapper") -} - -type IDSameMapper struct { - ID int64 -} - -func TestSameMapperID(t *testing.T) { - assert.NoError(t, prepareEngine()) - - oldMapper := testEngine.GetColumnMapper() - testEngine.UnMapType(rValue(new(IDSameMapper)).Type()) - testEngine.SetMapper(core.SameMapper{}) - defer func() { - testEngine.UnMapType(rValue(new(IDSameMapper)).Type()) - testEngine.SetMapper(oldMapper) - }() - - err := testEngine.CreateTables(new(IDSameMapper)) - if err != nil { - t.Fatal(err) - } - - tables, err := testEngine.DBMetas() - if err != nil { - t.Fatal(err) - } - - for _, tb := range tables { - if tb.Name == "IDSameMapper" { - if len(tb.PKColumns()) != 1 || tb.PKColumns()[0].Name != "ID" { - t.Fatalf("tb %s tb.PKColumns() is %d not 1, tb.PKColumns()[0].Name is %s not ID", tb.Name, len(tb.PKColumns()), tb.PKColumns()[0].Name) - } - return - } - } - t.Fatal("not table IDSameMapper") -} diff --git a/tag_test.go b/tag_test.go deleted file mode 100644 index 4f5d1d6..0000000 --- a/tag_test.go +++ /dev/null @@ -1,601 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
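For reference, the aliased-extends query assembled in TestExtends5 above boils down to the following shape (backquote quoting as on MySQL); the extends('Closed') prefix on SizeClosed is what lets the ClosedWidth/ClosedHeight aliases land in that embedded struct:

// SELECT book.`id`,
//        sc.`Width` AS `ClosedWidth`, sc.`Height` AS `ClosedHeight`,
//        s.`Width`, s.`Height`
// FROM `book`
// LEFT JOIN `size` AS `sc` ON book.`SizeClosed` = sc.`id`
// LEFT JOIN `size` AS `s` ON book.`Size` = s.`id`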
- -package xorm - -import ( - "fmt" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/xormplus/core" -) - -type UserCU struct { - Id int64 - Name string - Created time.Time `xorm:"created"` - Updated time.Time `xorm:"updated"` -} - -func TestCreatedAndUpdated(t *testing.T) { - assert.NoError(t, prepareEngine()) - - u := new(UserCU) - err := testEngine.DropTables(u) - assert.NoError(t, err) - - err = testEngine.CreateTables(u) - assert.NoError(t, err) - - u.Name = "sss" - cnt, err := testEngine.Insert(u) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - u.Name = "xxx" - cnt, err = testEngine.ID(u.Id).Update(u) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - u.Id = 0 - u.Created = time.Now().Add(-time.Hour * 24 * 365) - u.Updated = u.Created - fmt.Println(u) - cnt, err = testEngine.NoAutoTime().Insert(u) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) -} - -type StrangeName struct { - Id_t int64 `xorm:"pk autoincr"` - Name string -} - -func TestStrangeName(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(new(StrangeName)) - assert.NoError(t, err) - - err = testEngine.CreateTables(new(StrangeName)) - assert.NoError(t, err) - - _, err = testEngine.Insert(&StrangeName{Name: "sfsfdsfds"}) - assert.NoError(t, err) - - beans := make([]StrangeName, 0) - err = testEngine.Find(&beans) - assert.NoError(t, err) -} - -func TestCreatedUpdated(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type CreatedUpdated struct { - Id int64 - Name string - Value float64 `xorm:"numeric"` - Created time.Time `xorm:"created"` - Created2 time.Time `xorm:"created"` - Updated time.Time `xorm:"updated"` - } - - err := testEngine.Sync2(&CreatedUpdated{}) - assert.NoError(t, err) - - c := &CreatedUpdated{Name: "test"} - _, err = testEngine.Insert(c) - assert.NoError(t, err) - - c2 := new(CreatedUpdated) - has, err := testEngine.ID(c.Id).Get(c2) - assert.NoError(t, err) - - assert.True(t, has) - - c2.Value -= 1 - _, err = testEngine.ID(c2.Id).Update(c2) - assert.NoError(t, err) -} - -func TestCreatedUpdatedInt64(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type CreatedUpdatedInt64 struct { - Id int64 - Name string - Value float64 `xorm:"numeric"` - Created int64 `xorm:"created"` - Created2 int64 `xorm:"created"` - Updated int64 `xorm:"updated"` - } - - assertSync(t, &CreatedUpdatedInt64{}) - - c := &CreatedUpdatedInt64{Name: "test"} - _, err := testEngine.Insert(c) - assert.NoError(t, err) - - c2 := new(CreatedUpdatedInt64) - has, err := testEngine.ID(c.Id).Get(c2) - assert.NoError(t, err) - assert.True(t, has) - - c2.Value -= 1 - _, err = testEngine.ID(c2.Id).Update(c2) - assert.NoError(t, err) -} - -type Lowercase struct { - Id int64 - Name string - ended int64 `xorm:"-"` -} - -func TestLowerCase(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.Sync2(&Lowercase{}) - assert.NoError(t, err) - _, err = testEngine.Where("id > 0").Delete(&Lowercase{}) - assert.NoError(t, err) - - _, err = testEngine.Insert(&Lowercase{ended: 1}) - assert.NoError(t, err) - - ls := make([]Lowercase, 0) - err = testEngine.Find(&ls) - assert.NoError(t, err) - assert.EqualValues(t, 1, len(ls)) -} - -func TestAutoIncrTag(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type TestAutoIncr1 struct { - Id int64 - } - - tb := testEngine.TableInfo(new(TestAutoIncr1)) - cols := tb.Columns() - assert.EqualValues(t, 1, len(cols)) - assert.True(t, cols[0].IsAutoIncrement) - assert.True(t, 
cols[0].IsPrimaryKey) - assert.Equal(t, "id", cols[0].Name) - - type TestAutoIncr2 struct { - Id int64 `xorm:"id"` - } - - tb = testEngine.TableInfo(new(TestAutoIncr2)) - cols = tb.Columns() - assert.EqualValues(t, 1, len(cols)) - assert.False(t, cols[0].IsAutoIncrement) - assert.False(t, cols[0].IsPrimaryKey) - assert.Equal(t, "id", cols[0].Name) - - type TestAutoIncr3 struct { - Id int64 `xorm:"'ID'"` - } - - tb = testEngine.TableInfo(new(TestAutoIncr3)) - cols = tb.Columns() - assert.EqualValues(t, 1, len(cols)) - assert.False(t, cols[0].IsAutoIncrement) - assert.False(t, cols[0].IsPrimaryKey) - assert.Equal(t, "ID", cols[0].Name) - - type TestAutoIncr4 struct { - Id int64 `xorm:"pk"` - } - - tb = testEngine.TableInfo(new(TestAutoIncr4)) - cols = tb.Columns() - assert.EqualValues(t, 1, len(cols)) - assert.False(t, cols[0].IsAutoIncrement) - assert.True(t, cols[0].IsPrimaryKey) - assert.Equal(t, "id", cols[0].Name) -} - -func TestTagComment(t *testing.T) { - assert.NoError(t, prepareEngine()) - // FIXME: only support mysql - if testEngine.Dialect().DriverName() != core.MYSQL { - return - } - - type TestComment1 struct { - Id int64 `xorm:"comment(主键)"` - } - - assert.NoError(t, testEngine.Sync2(new(TestComment1))) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - assert.EqualValues(t, 1, len(tables)) - assert.EqualValues(t, 1, len(tables[0].Columns())) - assert.EqualValues(t, "主键", tables[0].Columns()[0].Comment) - - assert.NoError(t, testEngine.DropTables(new(TestComment1))) - - type TestComment2 struct { - Id int64 `xorm:"comment('主键')"` - } - - assert.NoError(t, testEngine.Sync2(new(TestComment2))) - - tables, err = testEngine.DBMetas() - assert.NoError(t, err) - assert.EqualValues(t, 1, len(tables)) - assert.EqualValues(t, 1, len(tables[0].Columns())) - assert.EqualValues(t, "主键", tables[0].Columns()[0].Comment) -} - -func TestTagDefault(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type DefaultStruct struct { - Id int64 - Name string - Age int `xorm:"default(10)"` - } - - assertSync(t, new(DefaultStruct)) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - - var defaultVal string - var isDefaultExist bool - tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct") - for _, table := range tables { - if table.Name == tableName { - col := table.GetColumn("age") - assert.NotNil(t, col) - defaultVal = col.Default - isDefaultExist = !col.DefaultIsEmpty - break - } - } - assert.True(t, isDefaultExist) - assert.EqualValues(t, "10", defaultVal) - - cnt, err := testEngine.Omit("age").Insert(&DefaultStruct{ - Name: "test", - Age: 20, - }) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - var s DefaultStruct - has, err := testEngine.ID(1).Get(&s) - assert.NoError(t, err) - assert.True(t, has) - assert.EqualValues(t, 10, s.Age) - assert.EqualValues(t, "test", s.Name) -} - -func TestTagDefault2(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type DefaultStruct2 struct { - Id int64 - Name string - } - - assertSync(t, new(DefaultStruct2)) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - - var defaultVal string - var isDefaultExist bool - tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct2") - for _, table := range tables { - if table.Name == tableName { - col := table.GetColumn("name") - assert.NotNil(t, col) - defaultVal = col.Default - isDefaultExist = !col.DefaultIsEmpty - break - } - } - assert.False(t, isDefaultExist, fmt.Sprintf("default value is --%v--", defaultVal)) - 
assert.EqualValues(t, "", defaultVal) -} - -func TestTagDefault3(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type DefaultStruct3 struct { - Id int64 - Name string `xorm:"default('myname')"` - } - - assertSync(t, new(DefaultStruct3)) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - - var defaultVal string - var isDefaultExist bool - tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct3") - for _, table := range tables { - if table.Name == tableName { - col := table.GetColumn("name") - assert.NotNil(t, col) - defaultVal = col.Default - isDefaultExist = !col.DefaultIsEmpty - break - } - } - assert.True(t, isDefaultExist) - assert.EqualValues(t, "'myname'", defaultVal) -} - -func TestTagDefault4(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type DefaultStruct4 struct { - Id int64 - Created time.Time `xorm:"default(CURRENT_TIMESTAMP)"` - } - - assertSync(t, new(DefaultStruct4)) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - - var defaultVal string - var isDefaultExist bool - tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct4") - for _, table := range tables { - if table.Name == tableName { - col := table.GetColumn("created") - assert.NotNil(t, col) - defaultVal = col.Default - isDefaultExist = !col.DefaultIsEmpty - break - } - } - assert.True(t, isDefaultExist) - assert.True(t, "CURRENT_TIMESTAMP" == defaultVal || - "now()" == defaultVal || - "getdate" == defaultVal, defaultVal) -} - -func TestTagDefault5(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type DefaultStruct5 struct { - Id int64 - Created time.Time `xorm:"default('2006-01-02 15:04:05')"` - } - - assertSync(t, new(DefaultStruct5)) - table := testEngine.TableInfo(new(DefaultStruct5)) - createdCol := table.GetColumn("created") - assert.NotNil(t, createdCol) - assert.EqualValues(t, "'2006-01-02 15:04:05'", createdCol.Default) - assert.False(t, createdCol.DefaultIsEmpty) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - - var defaultVal string - var isDefaultExist bool - tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct5") - for _, table := range tables { - if table.Name == tableName { - col := table.GetColumn("created") - assert.NotNil(t, col) - defaultVal = col.Default - isDefaultExist = !col.DefaultIsEmpty - break - } - } - assert.True(t, isDefaultExist) - assert.EqualValues(t, "'2006-01-02 15:04:05'", defaultVal) -} - -func TestTagDefault6(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type DefaultStruct6 struct { - Id int64 - IsMan bool `xorm:"default(true)"` - } - - assertSync(t, new(DefaultStruct6)) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - - var defaultVal string - var isDefaultExist bool - tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct6") - for _, table := range tables { - if table.Name == tableName { - col := table.GetColumn("is_man") - assert.NotNil(t, col) - defaultVal = col.Default - isDefaultExist = !col.DefaultIsEmpty - break - } - } - assert.True(t, isDefaultExist) - if defaultVal == "1" { - defaultVal = "true" - } else if defaultVal == "0" { - defaultVal = "false" - } - assert.EqualValues(t, "true", defaultVal) -} - -func TestTagsDirection(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type OnlyFromDBStruct struct { - Id int64 - Name string - Uuid string `xorm:"<- default '1'"` - } - - assertSync(t, new(OnlyFromDBStruct)) - - cnt, err := testEngine.Insert(&OnlyFromDBStruct{ - Name: "test", - Uuid: "2", - }) - 
assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - var s OnlyFromDBStruct - has, err := testEngine.ID(1).Get(&s) - assert.NoError(t, err) - assert.True(t, has) - assert.EqualValues(t, "1", s.Uuid) - assert.EqualValues(t, "test", s.Name) - - cnt, err = testEngine.ID(1).Update(&OnlyFromDBStruct{ - Uuid: "3", - Name: "test1", - }) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - var s3 OnlyFromDBStruct - has, err = testEngine.ID(1).Get(&s3) - assert.NoError(t, err) - assert.True(t, has) - assert.EqualValues(t, "1", s3.Uuid) - assert.EqualValues(t, "test1", s3.Name) - - type OnlyToDBStruct struct { - Id int64 - Name string - Uuid string `xorm:"->"` - } - - assertSync(t, new(OnlyToDBStruct)) - - cnt, err = testEngine.Insert(&OnlyToDBStruct{ - Name: "test", - Uuid: "2", - }) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - var s2 OnlyToDBStruct - has, err = testEngine.ID(1).Get(&s2) - assert.NoError(t, err) - assert.True(t, has) - assert.EqualValues(t, "", s2.Uuid) - assert.EqualValues(t, "test", s2.Name) -} - -func TestTagTime(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type TagUTCStruct struct { - Id int64 - Name string - Created time.Time `xorm:"created utc"` - } - - assertSync(t, new(TagUTCStruct)) - - assert.EqualValues(t, time.Local.String(), testEngine.GetTZLocation().String()) - - s := TagUTCStruct{ - Name: "utc", - } - cnt, err := testEngine.Insert(&s) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - var u TagUTCStruct - has, err := testEngine.ID(1).Get(&u) - assert.NoError(t, err) - assert.True(t, has) - assert.EqualValues(t, s.Created.Format("2006-01-02 15:04:05"), u.Created.Format("2006-01-02 15:04:05")) - - var tm string - has, err = testEngine.Table("tag_u_t_c_struct").Cols("created").Get(&tm) - assert.NoError(t, err) - assert.True(t, has) - assert.EqualValues(t, s.Created.UTC().Format("2006-01-02 15:04:05"), - strings.Replace(strings.Replace(tm, "T", " ", -1), "Z", "", -1)) -} - -func TestSplitTag(t *testing.T) { - var cases = []struct { - tag string - tags []string - }{ - {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, - {"TEXT", []string{"TEXT"}}, - {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, - {"json binary", []string{"json", "binary"}}, - } - - for _, kase := range cases { - tags := splitTag(kase.tag) - if !sliceEq(tags, kase.tags) { - t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) - } - } -} - -func TestTagAutoIncr(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type TagAutoIncr struct { - Id int64 - Name string - } - - assertSync(t, new(TagAutoIncr)) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - assert.EqualValues(t, 1, len(tables)) - assert.EqualValues(t, tableMapper.Obj2Table("TagAutoIncr"), tables[0].Name) - col := tables[0].GetColumn(colMapper.Obj2Table("Id")) - assert.NotNil(t, col) - assert.True(t, col.IsPrimaryKey) - assert.True(t, col.IsAutoIncrement) - - col2 := tables[0].GetColumn(colMapper.Obj2Table("Name")) - assert.NotNil(t, col2) - assert.False(t, col2.IsPrimaryKey) - assert.False(t, col2.IsAutoIncrement) -} - -func TestTagPrimarykey(t *testing.T) { - assert.NoError(t, prepareEngine()) - type TagPrimaryKey struct { - Id int64 `xorm:"pk"` - Name string `xorm:"VARCHAR(20) pk"` - } - - assertSync(t, new(TagPrimaryKey)) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - assert.EqualValues(t, 1, len(tables)) -
assert.EqualValues(t, tableMapper.Obj2Table("TagPrimaryKey"), tables[0].Name) - col := tables[0].GetColumn(colMapper.Obj2Table("Id")) - assert.NotNil(t, col) - assert.True(t, col.IsPrimaryKey) - assert.False(t, col.IsAutoIncrement) - - col2 := tables[0].GetColumn(colMapper.Obj2Table("Name")) - assert.NotNil(t, col2) - assert.True(t, col2.IsPrimaryKey) - assert.False(t, col2.IsAutoIncrement) -} diff --git a/tag_version_test.go b/tag_version_test.go deleted file mode 100644 index cd6dc93..0000000 --- a/tag_version_test.go +++ /dev/null @@ -1,242 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "errors" - "fmt" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -type VersionS struct { - Id int64 - Name string - Ver int `xorm:"version"` - Created time.Time `xorm:"created"` -} - -func TestVersion1(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(new(VersionS)) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(new(VersionS)) - if err != nil { - t.Error(err) - panic(err) - } - - ver := &VersionS{Name: "sfsfdsfds"} - _, err = testEngine.Insert(ver) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(ver) - if ver.Ver != 1 { - err = errors.New("insert error") - t.Error(err) - panic(err) - } - - newVer := new(VersionS) - has, err := testEngine.ID(ver.Id).Get(newVer) - if err != nil { - t.Error(err) - panic(err) - } - - if !has { - t.Error(errors.New(fmt.Sprintf("no version id is %v", ver.Id))) - panic(err) - } - fmt.Println(newVer) - if newVer.Ver != 1 { - err = errors.New("insert error") - t.Error(err) - panic(err) - } - - newVer.Name = "-------" - _, err = testEngine.ID(ver.Id).Update(newVer) - if err != nil { - t.Error(err) - panic(err) - } - if newVer.Ver != 2 { - err = errors.New("update should set version back to struct") - t.Error(err) - } - - newVer = new(VersionS) - has, err = testEngine.ID(ver.Id).Get(newVer) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(newVer) - if newVer.Ver != 2 { - err = errors.New("update error") - t.Error(err) - panic(err) - } -} - -func TestVersion2(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(new(VersionS)) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(new(VersionS)) - if err != nil { - t.Error(err) - panic(err) - } - - var vers = []VersionS{ - {Name: "sfsfdsfds"}, - {Name: "xxxxx"}, - } - _, err = testEngine.Insert(vers) - if err != nil { - t.Error(err) - panic(err) - } - - fmt.Println(vers) - - for _, v := range vers { - if v.Ver != 1 { - err := errors.New("version should be 1") - t.Error(err) - panic(err) - } - } -} - -type VersionUintS struct { - Id int64 - Name string - Ver uint `xorm:"version"` - Created time.Time `xorm:"created"` -} - -func TestVersion3(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(new(VersionUintS)) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(new(VersionUintS)) - if err != nil { - t.Error(err) - panic(err) - } - - ver := &VersionUintS{Name: "sfsfdsfds"} - _, err = testEngine.Insert(ver) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(ver) - if ver.Ver != 1 { - err = errors.New("insert error") - t.Error(err) - panic(err) - } - - newVer := new(VersionUintS) - has, err := testEngine.ID(ver.Id).Get(newVer) 
- if err != nil { - t.Error(err) - panic(err) - } - - if !has { - t.Error(errors.New(fmt.Sprintf("no version id is %v", ver.Id))) - panic(err) - } - fmt.Println(newVer) - if newVer.Ver != 1 { - err = errors.New("insert error") - t.Error(err) - panic(err) - } - - newVer.Name = "-------" - _, err = testEngine.ID(ver.Id).Update(newVer) - if err != nil { - t.Error(err) - panic(err) - } - if newVer.Ver != 2 { - err = errors.New("update should set version back to struct") - t.Error(err) - } - - newVer = new(VersionUintS) - has, err = testEngine.ID(ver.Id).Get(newVer) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(newVer) - if newVer.Ver != 2 { - err = errors.New("update error") - t.Error(err) - panic(err) - } -} - -func TestVersion4(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(new(VersionUintS)) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(new(VersionUintS)) - if err != nil { - t.Error(err) - panic(err) - } - - var vers = []VersionUintS{ - {Name: "sfsfdsfds"}, - {Name: "xxxxx"}, - } - _, err = testEngine.Insert(vers) - if err != nil { - t.Error(err) - panic(err) - } - - fmt.Println(vers) - - for _, v := range vers { - if v.Ver != 1 { - err := errors.New("version should be 1") - t.Error(err) - panic(err) - } - } -} diff --git a/tags/parser.go b/tags/parser.go new file mode 100644 index 0000000..eadedb8 --- /dev/null +++ b/tags/parser.go @@ -0,0 +1,307 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tags + +import ( + "encoding/gob" + "errors" + "fmt" + "reflect" + "strings" + "sync" + "time" + + "github.com/xormplus/xorm/caches" + "github.com/xormplus/xorm/convert" + "github.com/xormplus/xorm/dialects" + "github.com/xormplus/xorm/names" + "github.com/xormplus/xorm/schemas" +) + +var ( + ErrUnsupportedType = errors.New("Unsupported type") +) + +type Parser struct { + identifier string + dialect dialects.Dialect + columnMapper names.Mapper + tableMapper names.Mapper + handlers map[string]Handler + cacherMgr *caches.Manager + tableCache sync.Map // map[reflect.Type]*schemas.Table +} + +func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnMapper names.Mapper, cacherMgr *caches.Manager) *Parser { + return &Parser{ + identifier: identifier, + dialect: dialect, + tableMapper: tableMapper, + columnMapper: columnMapper, + handlers: defaultTagHandlers, + cacherMgr: cacherMgr, + } +} + +func (parser *Parser) GetTableMapper() names.Mapper { + return parser.tableMapper +} + +func (parser *Parser) SetTableMapper(mapper names.Mapper) { + parser.ClearCaches() + parser.tableMapper = mapper +} + +func (parser *Parser) GetColumnMapper() names.Mapper { + return parser.columnMapper +} + +func (parser *Parser) SetColumnMapper(mapper names.Mapper) { + parser.ClearCaches() + parser.columnMapper = mapper +} + +func (parser *Parser) ParseWithCache(v reflect.Value) (*schemas.Table, error) { + t := v.Type() + tableI, ok := parser.tableCache.Load(t) + if ok { + return tableI.(*schemas.Table), nil + } + + table, err := parser.Parse(v) + if err != nil { + return nil, err + } + + parser.tableCache.Store(t, table) + + if parser.cacherMgr.GetDefaultCacher() != nil { + if v.CanAddr() { + gob.Register(v.Addr().Interface()) + } else { + gob.Register(v.Interface()) + } + } + + return table, nil +} + +// ClearCacheTable removes the database mapper of a type from the cache +func (parser 
*Parser) ClearCacheTable(t reflect.Type) { + parser.tableCache.Delete(t) +} + +// ClearCaches removes all the cached table information parsed by structs +func (parser *Parser) ClearCaches() { + parser.tableCache = sync.Map{} +} + +func addIndex(indexName string, table *schemas.Table, col *schemas.Column, indexType int) { + if index, ok := table.Indexes[indexName]; ok { + index.AddColumn(col.Name) + col.Indexes[index.Name] = indexType + } else { + index := schemas.NewIndex(indexName, indexType) + index.AddColumn(col.Name) + table.AddIndex(index) + col.Indexes[index.Name] = indexType + } +} + +// Parse parses a struct as a table information +func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { + t := v.Type() + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return nil, ErrUnsupportedType + } + + table := schemas.NewEmptyTable() + table.Type = t + table.Name = names.GetTableName(parser.tableMapper, v) + + var idFieldColName string + var hasCacheTag, hasNoCacheTag bool + + for i := 0; i < t.NumField(); i++ { + tag := t.Field(i).Tag + + ormTagStr := tag.Get(parser.identifier) + var col *schemas.Column + fieldValue := v.Field(i) + fieldType := fieldValue.Type() + + if ormTagStr != "" { + col = &schemas.Column{ + FieldName: t.Field(i).Name, + Nullable: true, + IsPrimaryKey: false, + IsAutoIncrement: false, + MapType: schemas.TWOSIDES, + Indexes: make(map[string]int), + DefaultIsEmpty: true, + } + tags := splitTag(ormTagStr) + + if len(tags) > 0 { + if tags[0] == "-" { + continue + } + + var ctx = Context{ + table: table, + col: col, + fieldValue: fieldValue, + indexNames: make(map[string]int), + parser: parser, + } + + if strings.HasPrefix(strings.ToUpper(tags[0]), "EXTENDS") { + pStart := strings.Index(tags[0], "(") + if pStart > -1 && strings.HasSuffix(tags[0], ")") { + var tagPrefix = strings.TrimFunc(tags[0][pStart+1:len(tags[0])-1], func(r rune) bool { + return r == '\'' || r == '"' + }) + + ctx.params = []string{tagPrefix} + } + + if err := ExtendsTagHandler(&ctx); err != nil { + return nil, err + } + continue + } + + for j, key := range tags { + if ctx.ignoreNext { + ctx.ignoreNext = false + continue + } + + k := strings.ToUpper(key) + ctx.tagName = k + ctx.params = []string{} + + pStart := strings.Index(k, "(") + if pStart == 0 { + return nil, errors.New("( could not be the first character") + } + if pStart > -1 { + if !strings.HasSuffix(k, ")") { + return nil, fmt.Errorf("field %s tag %s cannot match ) character", col.FieldName, key) + } + + ctx.tagName = k[:pStart] + ctx.params = strings.Split(key[pStart+1:len(k)-1], ",") + } + + if j > 0 { + ctx.preTag = strings.ToUpper(tags[j-1]) + } + if j < len(tags)-1 { + ctx.nextTag = tags[j+1] + } else { + ctx.nextTag = "" + } + + if h, ok := parser.handlers[ctx.tagName]; ok { + if err := h(&ctx); err != nil { + return nil, err + } + } else { + if strings.HasPrefix(key, "'") && strings.HasSuffix(key, "'") { + col.Name = key[1 : len(key)-1] + } else { + col.Name = key + } + } + + if ctx.hasCacheTag { + hasCacheTag = true + } + if ctx.hasNoCacheTag { + hasNoCacheTag = true + } + } + + if col.SQLType.Name == "" { + col.SQLType = schemas.Type2SQLType(fieldType) + } + parser.dialect.SQLType(col) + if col.Length == 0 { + col.Length = col.SQLType.DefaultLength + } + if col.Length2 == 0 { + col.Length2 = col.SQLType.DefaultLength2 + } + if col.Name == "" { + col.Name = parser.columnMapper.Obj2Table(t.Field(i).Name) + } + + if ctx.isUnique { + ctx.indexNames[col.Name] = schemas.UniqueType + } else if 
ctx.isIndex { + ctx.indexNames[col.Name] = schemas.IndexType + } + + for indexName, indexType := range ctx.indexNames { + addIndex(indexName, table, col, indexType) + } + } + } else { + var sqlType schemas.SQLType + if fieldValue.CanAddr() { + if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { + sqlType = schemas.SQLType{Name: schemas.Text} + } + } + if _, ok := fieldValue.Interface().(convert.Conversion); ok { + sqlType = schemas.SQLType{Name: schemas.Text} + } else { + sqlType = schemas.Type2SQLType(fieldType) + } + col = schemas.NewColumn(parser.columnMapper.Obj2Table(t.Field(i).Name), + t.Field(i).Name, sqlType, sqlType.DefaultLength, + sqlType.DefaultLength2, true) + + if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { + idFieldColName = col.Name + } + } + if col.IsAutoIncrement { + col.Nullable = false + } + + table.AddColumn(col) + + } // end for + + if idFieldColName != "" && len(table.PrimaryKeys) == 0 { + col := table.GetColumn(idFieldColName) + col.IsPrimaryKey = true + col.IsAutoIncrement = true + col.Nullable = false + table.PrimaryKeys = append(table.PrimaryKeys, col.Name) + table.AutoIncrement = col.Name + } + + if hasCacheTag { + if parser.cacherMgr.GetDefaultCacher() != nil { // !nash! use engine's cacher if provided + //engine.logger.Info("enable cache on table:", table.Name) + parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher()) + } else { + //engine.logger.Info("enable LRU cache on table:", table.Name) + parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)) + } + } + if hasNoCacheTag { + //engine.logger.Info("disable cache on table:", table.Name) + parser.cacherMgr.SetCacher(table.Name, nil) + } + + return table, nil +} diff --git a/tags/parser_test.go b/tags/parser_test.go new file mode 100644 index 0000000..1b735ac --- /dev/null +++ b/tags/parser_test.go @@ -0,0 +1,44 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tags + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/xormplus/xorm/dialects" + "github.com/xormplus/xorm/names" + "github.com/xormplus/xorm/caches" +) + +type ParseTableName1 struct{} + +type ParseTableName2 struct{} + +func (p ParseTableName2) TableName() string { + return "p_parseTableName" +} + +func TestParseTableName(t *testing.T) { + parser := NewParser( + "xorm", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.SnakeMapper{}, + caches.NewManager(), + ) + table, err := parser.Parse(reflect.ValueOf(new(ParseTableName1))) + assert.NoError(t, err) + assert.EqualValues(t, "parse_table_name1", table.Name) + + table, err = parser.Parse(reflect.ValueOf(new(ParseTableName2))) + assert.NoError(t, err) + assert.EqualValues(t, "p_parseTableName", table.Name) + + table, err = parser.Parse(reflect.ValueOf(ParseTableName2{})) + assert.NoError(t, err) + assert.EqualValues(t, "p_parseTableName", table.Name) +} diff --git a/tag.go b/tags/tag.go similarity index 72% rename from tag.go rename to tags/tag.go index b5b711a..257dd74 100644 --- a/tag.go +++ b/tags/tag.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file.
-package xorm +package tags import ( "fmt" @@ -11,31 +11,52 @@ import ( "strings" "time" - "github.com/xormplus/core" + "github.com/xormplus/xorm/schemas" ) -type tagContext struct { +func splitTag(tag string) (tags []string) { + tag = strings.TrimSpace(tag) + var hasQuote = false + var lastIdx = 0 + for i, t := range tag { + if t == '\'' { + hasQuote = !hasQuote + } else if t == ' ' { + if lastIdx < i && !hasQuote { + tags = append(tags, strings.TrimSpace(tag[lastIdx:i])) + lastIdx = i + 1 + } + } + } + if lastIdx < len(tag) { + tags = append(tags, strings.TrimSpace(tag[lastIdx:])) + } + return +} + +// Context represents a context for xorm tag parse. +type Context struct { tagName string params []string preTag, nextTag string - table *core.Table - col *core.Column + table *schemas.Table + col *schemas.Column fieldValue reflect.Value isIndex bool isUnique bool indexNames map[string]int - engine *Engine + parser *Parser hasCacheTag bool hasNoCacheTag bool ignoreNext bool } -// tagHandler describes tag handler for XORM -type tagHandler func(ctx *tagContext) error +// Handler describes tag handler for XORM +type Handler func(ctx *Context) error var ( // defaultTagHandlers enumerates all the default tag handler - defaultTagHandlers = map[string]tagHandler{ + defaultTagHandlers = map[string]Handler{ "<-": OnlyFromDBTagHandler, "->": OnlyToDBTagHandler, "PK": PKTagHandler, @@ -59,49 +80,49 @@ var ( ) func init() { - for k := range core.SqlTypes { + for k := range schemas.SqlTypes { defaultTagHandlers[k] = SQLTypeTagHandler } } // IgnoreTagHandler describes ignored tag handler -func IgnoreTagHandler(ctx *tagContext) error { +func IgnoreTagHandler(ctx *Context) error { return nil } // OnlyFromDBTagHandler describes mapping direction tag handler -func OnlyFromDBTagHandler(ctx *tagContext) error { - ctx.col.MapType = core.ONLYFROMDB +func OnlyFromDBTagHandler(ctx *Context) error { + ctx.col.MapType = schemas.ONLYFROMDB return nil } // OnlyToDBTagHandler describes mapping direction tag handler -func OnlyToDBTagHandler(ctx *tagContext) error { - ctx.col.MapType = core.ONLYTODB +func OnlyToDBTagHandler(ctx *Context) error { + ctx.col.MapType = schemas.ONLYTODB return nil } -// PKTagHandler decribes primary key tag handler -func PKTagHandler(ctx *tagContext) error { +// PKTagHandler describes primary key tag handler +func PKTagHandler(ctx *Context) error { ctx.col.IsPrimaryKey = true ctx.col.Nullable = false return nil } // NULLTagHandler describes null tag handler -func NULLTagHandler(ctx *tagContext) error { +func NULLTagHandler(ctx *Context) error { ctx.col.Nullable = (strings.ToUpper(ctx.preTag) != "NOT") return nil } // NotNullTagHandler describes notnull tag handler -func NotNullTagHandler(ctx *tagContext) error { +func NotNullTagHandler(ctx *Context) error { ctx.col.Nullable = false return nil } // AutoIncrTagHandler describes autoincr tag handler -func AutoIncrTagHandler(ctx *tagContext) error { +func AutoIncrTagHandler(ctx *Context) error { ctx.col.IsAutoIncrement = true /* if len(ctx.params) > 0 { @@ -118,7 +139,7 @@ func AutoIncrTagHandler(ctx *tagContext) error { } // DefaultTagHandler describes default tag handler -func DefaultTagHandler(ctx *tagContext) error { +func DefaultTagHandler(ctx *Context) error { if len(ctx.params) > 0 { ctx.col.Default = ctx.params[0] } else { @@ -130,26 +151,26 @@ func DefaultTagHandler(ctx *tagContext) error { } // CreatedTagHandler describes created tag handler -func CreatedTagHandler(ctx *tagContext) error { +func CreatedTagHandler(ctx *Context) error { 
ctx.col.IsCreated = true return nil } // VersionTagHandler describes version tag handler -func VersionTagHandler(ctx *tagContext) error { +func VersionTagHandler(ctx *Context) error { ctx.col.IsVersion = true ctx.col.Default = "1" return nil } // UTCTagHandler describes utc tag handler -func UTCTagHandler(ctx *tagContext) error { +func UTCTagHandler(ctx *Context) error { ctx.col.TimeZone = time.UTC return nil } // LocalTagHandler describes local tag handler -func LocalTagHandler(ctx *tagContext) error { +func LocalTagHandler(ctx *Context) error { if len(ctx.params) == 0 { ctx.col.TimeZone = time.Local } else { @@ -163,21 +184,21 @@ func LocalTagHandler(ctx *tagContext) error { } // UpdatedTagHandler describes updated tag handler -func UpdatedTagHandler(ctx *tagContext) error { +func UpdatedTagHandler(ctx *Context) error { ctx.col.IsUpdated = true return nil } // DeletedTagHandler describes deleted tag handler -func DeletedTagHandler(ctx *tagContext) error { +func DeletedTagHandler(ctx *Context) error { ctx.col.IsDeleted = true return nil } // IndexTagHandler describes index tag handler -func IndexTagHandler(ctx *tagContext) error { +func IndexTagHandler(ctx *Context) error { if len(ctx.params) > 0 { - ctx.indexNames[ctx.params[0]] = core.IndexType + ctx.indexNames[ctx.params[0]] = schemas.IndexType } else { ctx.isIndex = true } @@ -185,9 +206,9 @@ func IndexTagHandler(ctx *tagContext) error { } // UniqueTagHandler describes unique tag handler -func UniqueTagHandler(ctx *tagContext) error { +func UniqueTagHandler(ctx *Context) error { if len(ctx.params) > 0 { - ctx.indexNames[ctx.params[0]] = core.UniqueType + ctx.indexNames[ctx.params[0]] = schemas.UniqueType } else { ctx.isUnique = true } @@ -195,7 +216,7 @@ func UniqueTagHandler(ctx *tagContext) error { } // CommentTagHandler add comment to column -func CommentTagHandler(ctx *tagContext) error { +func CommentTagHandler(ctx *Context) error { if len(ctx.params) > 0 { ctx.col.Comment = strings.Trim(ctx.params[0], "' ") } @@ -203,17 +224,17 @@ func CommentTagHandler(ctx *tagContext) error { } // SQLTypeTagHandler describes SQL Type tag handler -func SQLTypeTagHandler(ctx *tagContext) error { - ctx.col.SQLType = core.SQLType{Name: ctx.tagName} +func SQLTypeTagHandler(ctx *Context) error { + ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName} if len(ctx.params) > 0 { - if ctx.tagName == core.Enum { + if ctx.tagName == schemas.Enum { ctx.col.EnumOptions = make(map[string]int) for k, v := range ctx.params { v = strings.TrimSpace(v) v = strings.Trim(v, "'") ctx.col.EnumOptions[v] = k } - } else if ctx.tagName == core.Set { + } else if ctx.tagName == schemas.Set { ctx.col.SetOptions = make(map[string]int) for k, v := range ctx.params { v = strings.TrimSpace(v) @@ -243,7 +264,7 @@ func SQLTypeTagHandler(ctx *tagContext) error { } // ExtendsTagHandler describes extends tag handler -func ExtendsTagHandler(ctx *tagContext) error { +func ExtendsTagHandler(ctx *Context) error { var fieldValue = ctx.fieldValue var isPtr = false switch fieldValue.Kind() { @@ -259,7 +280,7 @@ func ExtendsTagHandler(ctx *tagContext) error { isPtr = true fallthrough case reflect.Struct: - parentTable, err := ctx.engine.mapType(fieldValue) + parentTable, err := ctx.parser.Parse(fieldValue) if err != nil { return err } @@ -295,7 +316,7 @@ func ExtendsTagHandler(ctx *tagContext) error { } // CacheTagHandler describes cache tag handler -func CacheTagHandler(ctx *tagContext) error { +func CacheTagHandler(ctx *Context) error { if !ctx.hasCacheTag { ctx.hasCacheTag = true } @@ 
-303,7 +324,7 @@ func CacheTagHandler(ctx *tagContext) error { } // NoCacheTagHandler describes nocache tag handler -func NoCacheTagHandler(ctx *tagContext) error { +func NoCacheTagHandler(ctx *Context) error { if !ctx.hasNoCacheTag { ctx.hasNoCacheTag = true } diff --git a/tags/tag_test.go b/tags/tag_test.go new file mode 100644 index 0000000..c9aaef2 --- /dev/null +++ b/tags/tag_test.go @@ -0,0 +1,30 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tags + +import ( + "testing" + + "github.com/xormplus/xorm/internal/utils" +) + +func TestSplitTag(t *testing.T) { + var cases = []struct { + tag string + tags []string + }{ + {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, + {"TEXT", []string{"TEXT"}}, + {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, + {"json binary", []string{"json", "binary"}}, + } + + for _, kase := range cases { + tags := splitTag(kase.tag) + if !utils.SliceEq(tags, kase.tags) { + t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) + } + } +} diff --git a/transaction.go b/transaction.go deleted file mode 100644 index 4104103..0000000 --- a/transaction.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2018 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -// Transaction Execute sql wrapped in a transaction(abbr as tx), tx will automatic commit if no errors occurred -func (engine *Engine) Transaction(f func(*Session) (interface{}, error)) (interface{}, error) { - session := engine.NewSession() - defer session.Close() - - if err := session.Begin(); err != nil { - return nil, err - } - - result, err := f(session) - if err != nil { - return nil, err - } - - if err := session.Commit(); err != nil { - return nil, err - } - - return result, nil -} diff --git a/transancation_test.go b/transancation_test.go deleted file mode 100644 index b9a8987..0000000 --- a/transancation_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package xorm - -import ( - "fmt" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestAutoTransaction(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type TestTx struct { - Id int64 `xorm:"autoincr pk"` - Msg string `xorm:"varchar(255)"` - Created time.Time `xorm:"created"` - } - - assert.NoError(t, testEngine.Sync2(new(TestTx))) - - engine := testEngine.(*Engine) - - // will success - engine.Transaction(func(session *Session) (interface{}, error) { - _, err := session.Insert(TestTx{Msg: "hi"}) - assert.NoError(t, err) - - return nil, nil - }) - - has, err := engine.Exist(&TestTx{Msg: "hi"}) - assert.NoError(t, err) - assert.EqualValues(t, true, has) - - // will rollback - _, err = engine.Transaction(func(session *Session) (interface{}, error) { - _, err := session.Insert(TestTx{Msg: "hello"}) - assert.NoError(t, err) - - return nil, fmt.Errorf("rollback") - }) - assert.Error(t, err) - - has, err = engine.Exist(&TestTx{Msg: "hello"}) - assert.NoError(t, err) - assert.EqualValues(t, false, has) -} diff --git a/vendor/github.com/syndtr/goleveldb/.travis.yml b/vendor/github.com/syndtr/goleveldb/.travis.yml new file mode 100644 index 0000000..66c3078 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/.travis.yml @@ -0,0 +1,12 @@ +language: go + +go: + - 1.9.x + - 1.10.x + - 1.11.x + - tip + +script: + - go vet ./... + - go test -timeout 1h ./... + - go test -timeout 30m -race -run "TestDB_(Concurrent|GoleveldbIssue74)" ./leveldb diff --git a/vendor/github.com/syndtr/goleveldb/LICENSE b/vendor/github.com/syndtr/goleveldb/LICENSE new file mode 100644 index 0000000..4a772d1 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/LICENSE @@ -0,0 +1,24 @@ +Copyright 2012 Suryandaru Triandana <syndtr@gmail.com> +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/syndtr/goleveldb/README.md b/vendor/github.com/syndtr/goleveldb/README.md new file mode 100644 index 0000000..41a4761 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/README.md @@ -0,0 +1,107 @@ +This is an implementation of the [LevelDB key/value database](http://code.google.com/p/leveldb) in the [Go programming language](http://golang.org).
+ +[![Build Status](https://travis-ci.org/syndtr/goleveldb.png?branch=master)](https://travis-ci.org/syndtr/goleveldb) + +Installation +----------- + + go get github.com/syndtr/goleveldb/leveldb + +Requirements +----------- + +* Need at least `go1.5` or newer. + +Usage +----------- + +Create or open a database: +```go +// The returned DB instance is safe for concurrent use. Which mean that all +// DB's methods may be called concurrently from multiple goroutine. +db, err := leveldb.OpenFile("path/to/db", nil) +... +defer db.Close() +... +``` +Read or modify the database content: +```go +// Remember that the contents of the returned slice should not be modified. +data, err := db.Get([]byte("key"), nil) +... +err = db.Put([]byte("key"), []byte("value"), nil) +... +err = db.Delete([]byte("key"), nil) +... +``` + +Iterate over database content: +```go +iter := db.NewIterator(nil, nil) +for iter.Next() { + // Remember that the contents of the returned slice should not be modified, and + // only valid until the next call to Next. + key := iter.Key() + value := iter.Value() + ... +} +iter.Release() +err = iter.Error() +... +``` +Seek-then-Iterate: +```go +iter := db.NewIterator(nil, nil) +for ok := iter.Seek(key); ok; ok = iter.Next() { + // Use key/value. + ... +} +iter.Release() +err = iter.Error() +... +``` +Iterate over subset of database content: +```go +iter := db.NewIterator(&util.Range{Start: []byte("foo"), Limit: []byte("xoo")}, nil) +for iter.Next() { + // Use key/value. + ... +} +iter.Release() +err = iter.Error() +... +``` +Iterate over subset of database content with a particular prefix: +```go +iter := db.NewIterator(util.BytesPrefix([]byte("foo-")), nil) +for iter.Next() { + // Use key/value. + ... +} +iter.Release() +err = iter.Error() +... +``` +Batch writes: +```go +batch := new(leveldb.Batch) +batch.Put([]byte("foo"), []byte("value")) +batch.Put([]byte("bar"), []byte("another value")) +batch.Delete([]byte("baz")) +err = db.Write(batch, nil) +... +``` +Use bloom filter: +```go +o := &opt.Options{ + Filter: filter.NewBloomFilter(10), +} +db, err := leveldb.OpenFile("path/to/db", o) +... +defer db.Close() +... +``` +Documentation +----------- + +You can read package documentation [here](http://godoc.org/github.com/syndtr/goleveldb).
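Note: the README above does not demonstrate the `Dump`/`Load`/`Replay` surface of the write batch, which the vendored `leveldb/batch.go` below defines. Here is a minimal sketch of how those calls fit together, assuming only the vendored API; the `replayLogger` type and the `main` scaffolding are illustrative, not part of the library:

```go
package main

import (
	"fmt"

	"github.com/syndtr/goleveldb/leveldb"
)

// replayLogger is an illustrative BatchReplay implementation that
// prints each operation recorded in a batch.
type replayLogger struct{}

func (replayLogger) Put(key, value []byte) { fmt.Printf("put %q=%q\n", key, value) }
func (replayLogger) Delete(key []byte)     { fmt.Printf("del %q\n", key) }

func main() {
	batch := new(leveldb.Batch)
	batch.Put([]byte("foo"), []byte("value"))
	batch.Delete([]byte("bar"))

	// Dump returns the batch's internal buffer; Load rebuilds a batch
	// from such a buffer (the slice is used directly, not copied).
	restored := new(leveldb.Batch)
	if err := restored.Load(batch.Dump()); err != nil {
		panic(err)
	}

	// Replay walks the recorded operations in insertion order.
	if err := restored.Replay(replayLogger{}); err != nil {
		panic(err)
	}
	fmt.Println("records:", restored.Len()) // records: 2
}
```

Since `*Batch` itself provides `Put` and `Delete` with matching signatures, it satisfies `BatchReplay`, so one batch can be replayed into another; `batch_test.go` below relies on exactly that.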
diff --git a/vendor/github.com/syndtr/goleveldb/go.mod b/vendor/github.com/syndtr/goleveldb/go.mod new file mode 100644 index 0000000..ddb3fc7 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/go.mod @@ -0,0 +1,9 @@ +module github.com/syndtr/goleveldb + +go 1.14 + +require ( + github.com/golang/snappy v0.0.1 + github.com/onsi/ginkgo v1.7.0 + github.com/onsi/gomega v1.4.3 +) diff --git a/vendor/github.com/syndtr/goleveldb/go.sum b/vendor/github.com/syndtr/goleveldb/go.sum new file mode 100644 index 0000000..13eccb4 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/go.sum @@ -0,0 +1,25 @@ +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs= +github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.4.3 h1:RE1xgDvH7imwFD45h+u2SgIfERHlS2yNG4DObb5BSKU= +github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e h1:o3PsSEY8E4eXWkXrIP9YJALUkVZqzHJT5DOasTyn8Vs= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/batch.go b/vendor/github.com/syndtr/goleveldb/leveldb/batch.go new file mode 100644 index 0000000..823be93 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/batch.go @@ -0,0 +1,354 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +package leveldb + +import ( + "encoding/binary" + "fmt" + "io" + + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/memdb" + "github.com/syndtr/goleveldb/leveldb/storage" +) + +// ErrBatchCorrupted records reason of batch corruption. This error will be +// wrapped with errors.ErrCorrupted. +type ErrBatchCorrupted struct { + Reason string +} + +func (e *ErrBatchCorrupted) Error() string { + return fmt.Sprintf("leveldb: batch corrupted: %s", e.Reason) +} + +func newErrBatchCorrupted(reason string) error { + return errors.NewErrCorrupted(storage.FileDesc{}, &ErrBatchCorrupted{reason}) +} + +const ( + batchHeaderLen = 8 + 4 + batchGrowRec = 3000 + batchBufioSize = 16 +) + +// BatchReplay wraps basic batch operations. +type BatchReplay interface { + Put(key, value []byte) + Delete(key []byte) +} + +type batchIndex struct { + keyType keyType + keyPos, keyLen int + valuePos, valueLen int +} + +func (index batchIndex) k(data []byte) []byte { + return data[index.keyPos : index.keyPos+index.keyLen] +} + +func (index batchIndex) v(data []byte) []byte { + if index.valueLen != 0 { + return data[index.valuePos : index.valuePos+index.valueLen] + } + return nil +} + +func (index batchIndex) kv(data []byte) (key, value []byte) { + return index.k(data), index.v(data) +} + +// Batch is a write batch. +type Batch struct { + data []byte + index []batchIndex + + // internalLen is sums of key/value pair length plus 8-bytes internal key. + internalLen int +} + +func (b *Batch) grow(n int) { + o := len(b.data) + if cap(b.data)-o < n { + div := 1 + if len(b.index) > batchGrowRec { + div = len(b.index) / batchGrowRec + } + ndata := make([]byte, o, o+n+o/div) + copy(ndata, b.data) + b.data = ndata + } +} + +func (b *Batch) appendRec(kt keyType, key, value []byte) { + n := 1 + binary.MaxVarintLen32 + len(key) + if kt == keyTypeVal { + n += binary.MaxVarintLen32 + len(value) + } + b.grow(n) + index := batchIndex{keyType: kt} + o := len(b.data) + data := b.data[:o+n] + data[o] = byte(kt) + o++ + o += binary.PutUvarint(data[o:], uint64(len(key))) + index.keyPos = o + index.keyLen = len(key) + o += copy(data[o:], key) + if kt == keyTypeVal { + o += binary.PutUvarint(data[o:], uint64(len(value))) + index.valuePos = o + index.valueLen = len(value) + o += copy(data[o:], value) + } + b.data = data[:o] + b.index = append(b.index, index) + b.internalLen += index.keyLen + index.valueLen + 8 +} + +// Put appends 'put operation' of the given key/value pair to the batch. +// It is safe to modify the contents of the argument after Put returns but not +// before. +func (b *Batch) Put(key, value []byte) { + b.appendRec(keyTypeVal, key, value) +} + +// Delete appends 'delete operation' of the given key to the batch. +// It is safe to modify the contents of the argument after Delete returns but +// not before. +func (b *Batch) Delete(key []byte) { + b.appendRec(keyTypeDel, key, nil) +} + +// Dump dumps batch contents. The returned slice can be loaded into the +// batch using Load method. +// The returned slice is not its own copy, so the contents should not be +// modified. +func (b *Batch) Dump() []byte { + return b.data +} + +// Load loads given slice into the batch. Previous contents of the batch +// will be discarded. +// The given slice will not be copied and will be used as batch buffer, so +// it is not safe to modify the contents of the slice. +func (b *Batch) Load(data []byte) error { + return b.decode(data, -1) +} + +// Replay replays batch contents. 
+func (b *Batch) Replay(r BatchReplay) error { + for _, index := range b.index { + switch index.keyType { + case keyTypeVal: + r.Put(index.k(b.data), index.v(b.data)) + case keyTypeDel: + r.Delete(index.k(b.data)) + } + } + return nil +} + +// Len returns number of records in the batch. +func (b *Batch) Len() int { + return len(b.index) +} + +// Reset resets the batch. +func (b *Batch) Reset() { + b.data = b.data[:0] + b.index = b.index[:0] + b.internalLen = 0 +} + +func (b *Batch) replayInternal(fn func(i int, kt keyType, k, v []byte) error) error { + for i, index := range b.index { + if err := fn(i, index.keyType, index.k(b.data), index.v(b.data)); err != nil { + return err + } + } + return nil +} + +func (b *Batch) append(p *Batch) { + ob := len(b.data) + oi := len(b.index) + b.data = append(b.data, p.data...) + b.index = append(b.index, p.index...) + b.internalLen += p.internalLen + + // Updating index offset. + if ob != 0 { + for ; oi < len(b.index); oi++ { + index := &b.index[oi] + index.keyPos += ob + if index.valueLen != 0 { + index.valuePos += ob + } + } + } +} + +func (b *Batch) decode(data []byte, expectedLen int) error { + b.data = data + b.index = b.index[:0] + b.internalLen = 0 + err := decodeBatch(data, func(i int, index batchIndex) error { + b.index = append(b.index, index) + b.internalLen += index.keyLen + index.valueLen + 8 + return nil + }) + if err != nil { + return err + } + if expectedLen >= 0 && len(b.index) != expectedLen { + return newErrBatchCorrupted(fmt.Sprintf("invalid records length: %d vs %d", expectedLen, len(b.index))) + } + return nil +} + +func (b *Batch) putMem(seq uint64, mdb *memdb.DB) error { + var ik []byte + for i, index := range b.index { + ik = makeInternalKey(ik, index.k(b.data), seq+uint64(i), index.keyType) + if err := mdb.Put(ik, index.v(b.data)); err != nil { + return err + } + } + return nil +} + +func (b *Batch) revertMem(seq uint64, mdb *memdb.DB) error { + var ik []byte + for i, index := range b.index { + ik = makeInternalKey(ik, index.k(b.data), seq+uint64(i), index.keyType) + if err := mdb.Delete(ik); err != nil { + return err + } + } + return nil +} + +func newBatch() interface{} { + return &Batch{} +} + +// MakeBatch returns empty batch with preallocated buffer. +func MakeBatch(n int) *Batch { + return &Batch{data: make([]byte, 0, n)} +} + +func decodeBatch(data []byte, fn func(i int, index batchIndex) error) error { + var index batchIndex + for i, o := 0, 0; o < len(data); i++ { + // Key type. + index.keyType = keyType(data[o]) + if index.keyType > keyTypeVal { + return newErrBatchCorrupted(fmt.Sprintf("bad record: invalid type %#x", uint(index.keyType))) + } + o++ + + // Key. + x, n := binary.Uvarint(data[o:]) + o += n + if n <= 0 || o+int(x) > len(data) { + return newErrBatchCorrupted("bad record: invalid key length") + } + index.keyPos = o + index.keyLen = int(x) + o += index.keyLen + + // Value. 
+ if index.keyType == keyTypeVal { + x, n = binary.Uvarint(data[o:]) + o += n + if n <= 0 || o+int(x) > len(data) { + return newErrBatchCorrupted("bad record: invalid value length") + } + index.valuePos = o + index.valueLen = int(x) + o += index.valueLen + } else { + index.valuePos = 0 + index.valueLen = 0 + } + + if err := fn(i, index); err != nil { + return err + } + } + return nil +} + +func decodeBatchToMem(data []byte, expectSeq uint64, mdb *memdb.DB) (seq uint64, batchLen int, err error) { + seq, batchLen, err = decodeBatchHeader(data) + if err != nil { + return 0, 0, err + } + if seq < expectSeq { + return 0, 0, newErrBatchCorrupted("invalid sequence number") + } + data = data[batchHeaderLen:] + var ik []byte + var decodedLen int + err = decodeBatch(data, func(i int, index batchIndex) error { + if i >= batchLen { + return newErrBatchCorrupted("invalid records length") + } + ik = makeInternalKey(ik, index.k(data), seq+uint64(i), index.keyType) + if err := mdb.Put(ik, index.v(data)); err != nil { + return err + } + decodedLen++ + return nil + }) + if err == nil && decodedLen != batchLen { + err = newErrBatchCorrupted(fmt.Sprintf("invalid records length: %d vs %d", batchLen, decodedLen)) + } + return +} + +func encodeBatchHeader(dst []byte, seq uint64, batchLen int) []byte { + dst = ensureBuffer(dst, batchHeaderLen) + binary.LittleEndian.PutUint64(dst, seq) + binary.LittleEndian.PutUint32(dst[8:], uint32(batchLen)) + return dst +} + +func decodeBatchHeader(data []byte) (seq uint64, batchLen int, err error) { + if len(data) < batchHeaderLen { + return 0, 0, newErrBatchCorrupted("too short") + } + + seq = binary.LittleEndian.Uint64(data) + batchLen = int(binary.LittleEndian.Uint32(data[8:])) + if batchLen < 0 { + return 0, 0, newErrBatchCorrupted("invalid records length") + } + return +} + +func batchesLen(batches []*Batch) int { + batchLen := 0 + for _, batch := range batches { + batchLen += batch.Len() + } + return batchLen +} + +func writeBatchesWithHeader(wr io.Writer, batches []*Batch, seq uint64) error { + if _, err := wr.Write(encodeBatchHeader(nil, seq, batchesLen(batches))); err != nil { + return err + } + for _, batch := range batches { + if _, err := wr.Write(batch.data); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/batch_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/batch_test.go new file mode 100644 index 0000000..62775f7 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/batch_test.go @@ -0,0 +1,147 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +package leveldb + +import ( + "bytes" + "fmt" + "testing" + "testing/quick" + + "github.com/syndtr/goleveldb/leveldb/testutil" +) + +func TestBatchHeader(t *testing.T) { + f := func(seq uint64, length uint32) bool { + encoded := encodeBatchHeader(nil, seq, int(length)) + decSeq, decLength, err := decodeBatchHeader(encoded) + return err == nil && decSeq == seq && decLength == int(length) + } + config := &quick.Config{ + Rand: testutil.NewRand(), + } + if err := quick.Check(f, config); err != nil { + t.Error(err) + } +} + +type batchKV struct { + kt keyType + k, v []byte +} + +func TestBatch(t *testing.T) { + var ( + kvs []batchKV + internalLen int + ) + batch := new(Batch) + rbatch := new(Batch) + abatch := new(Batch) + testBatch := func(i int, kt keyType, k, v []byte) error { + kv := kvs[i] + if kv.kt != kt { + return fmt.Errorf("invalid key type, index=%d: %d vs %d", i, kv.kt, kt) + } + if !bytes.Equal(kv.k, k) { + return fmt.Errorf("invalid key, index=%d", i) + } + if !bytes.Equal(kv.v, v) { + return fmt.Errorf("invalid value, index=%d", i) + } + return nil + } + f := func(ktr uint8, k, v []byte) bool { + kt := keyType(ktr % 2) + if kt == keyTypeVal { + batch.Put(k, v) + rbatch.Put(k, v) + kvs = append(kvs, batchKV{kt: kt, k: k, v: v}) + internalLen += len(k) + len(v) + 8 + } else { + batch.Delete(k) + rbatch.Delete(k) + kvs = append(kvs, batchKV{kt: kt, k: k}) + internalLen += len(k) + 8 + } + if batch.Len() != len(kvs) { + t.Logf("batch.Len: %d vs %d", len(kvs), batch.Len()) + return false + } + if batch.internalLen != internalLen { + t.Logf("abatch.internalLen: %d vs %d", internalLen, batch.internalLen) + return false + } + if len(kvs)%1000 == 0 { + if err := batch.replayInternal(testBatch); err != nil { + t.Logf("batch.replayInternal: %v", err) + return false + } + + abatch.append(rbatch) + rbatch.Reset() + if abatch.Len() != len(kvs) { + t.Logf("abatch.Len: %d vs %d", len(kvs), abatch.Len()) + return false + } + if abatch.internalLen != internalLen { + t.Logf("abatch.internalLen: %d vs %d", internalLen, abatch.internalLen) + return false + } + if err := abatch.replayInternal(testBatch); err != nil { + t.Logf("abatch.replayInternal: %v", err) + return false + } + + nbatch := new(Batch) + if err := nbatch.Load(batch.Dump()); err != nil { + t.Logf("nbatch.Load: %v", err) + return false + } + if nbatch.Len() != len(kvs) { + t.Logf("nbatch.Len: %d vs %d", len(kvs), nbatch.Len()) + return false + } + if nbatch.internalLen != internalLen { + t.Logf("nbatch.internalLen: %d vs %d", internalLen, nbatch.internalLen) + return false + } + if err := nbatch.replayInternal(testBatch); err != nil { + t.Logf("nbatch.replayInternal: %v", err) + return false + } + } + if len(kvs)%10000 == 0 { + nbatch := new(Batch) + if err := batch.Replay(nbatch); err != nil { + t.Logf("batch.Replay: %v", err) + return false + } + if nbatch.Len() != len(kvs) { + t.Logf("nbatch.Len: %d vs %d", len(kvs), nbatch.Len()) + return false + } + if nbatch.internalLen != internalLen { + t.Logf("nbatch.internalLen: %d vs %d", internalLen, nbatch.internalLen) + return false + } + if err := nbatch.replayInternal(testBatch); err != nil { + t.Logf("nbatch.replayInternal: %v", err) + return false + } + } + return true + } + config := &quick.Config{ + MaxCount: 40000, + Rand: testutil.NewRand(), + } + if err := quick.Check(f, config); err != nil { + t.Error(err) + } + t.Logf("length=%d internalLen=%d", len(kvs), internalLen) +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/bench_test.go 
b/vendor/github.com/syndtr/goleveldb/leveldb/bench_test.go new file mode 100644 index 0000000..435250c --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/bench_test.go @@ -0,0 +1,507 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "bytes" + "fmt" + "math/rand" + "os" + "path/filepath" + "runtime" + "sync/atomic" + "testing" + + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" +) + +func randomString(r *rand.Rand, n int) []byte { + b := new(bytes.Buffer) + for i := 0; i < n; i++ { + b.WriteByte(' ' + byte(r.Intn(95))) + } + return b.Bytes() +} + +func compressibleStr(r *rand.Rand, frac float32, n int) []byte { + nn := int(float32(n) * frac) + rb := randomString(r, nn) + b := make([]byte, 0, n+nn) + for len(b) < n { + b = append(b, rb...) + } + return b[:n] +} + +type valueGen struct { + src []byte + pos int +} + +func newValueGen(frac float32) *valueGen { + v := new(valueGen) + r := rand.New(rand.NewSource(301)) + v.src = make([]byte, 0, 1048576+100) + for len(v.src) < 1048576 { + v.src = append(v.src, compressibleStr(r, frac, 100)...) + } + return v +} + +func (v *valueGen) get(n int) []byte { + if v.pos+n > len(v.src) { + v.pos = 0 + } + v.pos += n + return v.src[v.pos-n : v.pos] +} + +var benchDB = filepath.Join(os.TempDir(), fmt.Sprintf("goleveldbbench-%d", os.Getuid())) + +type dbBench struct { + b *testing.B + stor storage.Storage + db *DB + + o *opt.Options + ro *opt.ReadOptions + wo *opt.WriteOptions + + keys, values [][]byte +} + +func openDBBench(b *testing.B, noCompress bool) *dbBench { + _, err := os.Stat(benchDB) + if err == nil { + err = os.RemoveAll(benchDB) + if err != nil { + b.Fatal("cannot remove old db: ", err) + } + } + + p := &dbBench{ + b: b, + o: &opt.Options{}, + ro: &opt.ReadOptions{}, + wo: &opt.WriteOptions{}, + } + p.stor, err = storage.OpenFile(benchDB, false) + if err != nil { + b.Fatal("cannot open stor: ", err) + } + if noCompress { + p.o.Compression = opt.NoCompression + } + + p.db, err = Open(p.stor, p.o) + if err != nil { + b.Fatal("cannot open db: ", err) + } + + return p +} + +func (p *dbBench) reopen() { + p.db.Close() + var err error + p.db, err = Open(p.stor, p.o) + if err != nil { + p.b.Fatal("Reopen: got error: ", err) + } +} + +func (p *dbBench) populate(n int) { + p.keys, p.values = make([][]byte, n), make([][]byte, n) + v := newValueGen(0.5) + for i := range p.keys { + p.keys[i], p.values[i] = []byte(fmt.Sprintf("%016d", i)), v.get(100) + } +} + +func (p *dbBench) randomize() { + m := len(p.keys) + times := m * 2 + r1, r2 := rand.New(rand.NewSource(0xdeadbeef)), rand.New(rand.NewSource(0xbeefface)) + for n := 0; n < times; n++ { + i, j := r1.Int()%m, r2.Int()%m + if i == j { + continue + } + p.keys[i], p.keys[j] = p.keys[j], p.keys[i] + p.values[i], p.values[j] = p.values[j], p.values[i] + } +} + +func (p *dbBench) writes(perBatch int) { + b := p.b + db := p.db + + n := len(p.keys) + m := n / perBatch + if n%perBatch > 0 { + m++ + } + batches := make([]Batch, m) + j := 0 + for i := range batches { + first := true + for ; j < n && ((j+1)%perBatch != 0 || first); j++ { + first = false + batches[i].Put(p.keys[j], p.values[j]) + } + } + runtime.GC() + + b.ResetTimer() + b.StartTimer() + for i := range batches { + err := db.Write(&(batches[i]), p.wo) + if err != nil { + b.Fatal("write 
failed: ", err) + } + } + b.StopTimer() + b.SetBytes(116) +} + +func (p *dbBench) gc() { + p.keys, p.values = nil, nil + runtime.GC() +} + +func (p *dbBench) puts() { + b := p.b + db := p.db + + b.ResetTimer() + b.StartTimer() + for i := range p.keys { + err := db.Put(p.keys[i], p.values[i], p.wo) + if err != nil { + b.Fatal("put failed: ", err) + } + } + b.StopTimer() + b.SetBytes(116) +} + +func (p *dbBench) fill() { + b := p.b + db := p.db + + perBatch := 10000 + batch := new(Batch) + for i, n := 0, len(p.keys); i < n; { + first := true + for ; i < n && ((i+1)%perBatch != 0 || first); i++ { + first = false + batch.Put(p.keys[i], p.values[i]) + } + err := db.Write(batch, p.wo) + if err != nil { + b.Fatal("write failed: ", err) + } + batch.Reset() + } +} + +func (p *dbBench) gets() { + b := p.b + db := p.db + + b.ResetTimer() + for i := range p.keys { + _, err := db.Get(p.keys[i], p.ro) + if err != nil { + b.Error("got error: ", err) + } + } + b.StopTimer() +} + +func (p *dbBench) seeks() { + b := p.b + + iter := p.newIter() + defer iter.Release() + b.ResetTimer() + for i := range p.keys { + if !iter.Seek(p.keys[i]) { + b.Error("value not found for: ", string(p.keys[i])) + } + } + b.StopTimer() +} + +func (p *dbBench) newIter() iterator.Iterator { + iter := p.db.NewIterator(nil, p.ro) + err := iter.Error() + if err != nil { + p.b.Fatal("cannot create iterator: ", err) + } + return iter +} + +func (p *dbBench) close() { + if bp, err := p.db.GetProperty("leveldb.blockpool"); err == nil { + p.b.Log("Block pool stats: ", bp) + } + p.db.Close() + p.stor.Close() + os.RemoveAll(benchDB) + p.db = nil + p.keys = nil + p.values = nil + runtime.GC() +} + +func BenchmarkDBWrite(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.writes(1) + p.close() +} + +func BenchmarkDBWriteBatch(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.writes(1000) + p.close() +} + +func BenchmarkDBWriteUncompressed(b *testing.B) { + p := openDBBench(b, true) + p.populate(b.N) + p.writes(1) + p.close() +} + +func BenchmarkDBWriteBatchUncompressed(b *testing.B) { + p := openDBBench(b, true) + p.populate(b.N) + p.writes(1000) + p.close() +} + +func BenchmarkDBWriteRandom(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.randomize() + p.writes(1) + p.close() +} + +func BenchmarkDBWriteRandomSync(b *testing.B) { + p := openDBBench(b, false) + p.wo.Sync = true + p.populate(b.N) + p.writes(1) + p.close() +} + +func BenchmarkDBOverwrite(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.writes(1) + p.writes(1) + p.close() +} + +func BenchmarkDBOverwriteRandom(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.writes(1) + p.randomize() + p.writes(1) + p.close() +} + +func BenchmarkDBPut(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.puts() + p.close() +} + +func BenchmarkDBRead(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.gc() + + iter := p.newIter() + b.ResetTimer() + for iter.Next() { + } + iter.Release() + b.StopTimer() + b.SetBytes(116) + p.close() +} + +func BenchmarkDBReadGC(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + + iter := p.newIter() + b.ResetTimer() + for iter.Next() { + } + iter.Release() + b.StopTimer() + b.SetBytes(116) + p.close() +} + +func BenchmarkDBReadUncompressed(b *testing.B) { + p := openDBBench(b, true) + p.populate(b.N) + p.fill() + p.gc() + + iter := p.newIter() + b.ResetTimer() + for iter.Next() { + } + iter.Release() + b.StopTimer() + 
b.SetBytes(116) + p.close() +} + +func BenchmarkDBReadTable(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.reopen() + p.gc() + + iter := p.newIter() + b.ResetTimer() + for iter.Next() { + } + iter.Release() + b.StopTimer() + b.SetBytes(116) + p.close() +} + +func BenchmarkDBReadReverse(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.gc() + + iter := p.newIter() + b.ResetTimer() + iter.Last() + for iter.Prev() { + } + iter.Release() + b.StopTimer() + b.SetBytes(116) + p.close() +} + +func BenchmarkDBReadReverseTable(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.reopen() + p.gc() + + iter := p.newIter() + b.ResetTimer() + iter.Last() + for iter.Prev() { + } + iter.Release() + b.StopTimer() + b.SetBytes(116) + p.close() +} + +func BenchmarkDBSeek(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.seeks() + p.close() +} + +func BenchmarkDBSeekRandom(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.randomize() + p.seeks() + p.close() +} + +func BenchmarkDBGet(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.gets() + p.close() +} + +func BenchmarkDBGetRandom(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.randomize() + p.gets() + p.close() +} + +func BenchmarkDBReadConcurrent(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.gc() + defer p.close() + + b.ResetTimer() + b.SetBytes(116) + + b.RunParallel(func(pb *testing.PB) { + iter := p.newIter() + defer iter.Release() + for pb.Next() && iter.Next() { + } + }) +} + +func BenchmarkDBReadConcurrent2(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.gc() + defer p.close() + + b.ResetTimer() + b.SetBytes(116) + + var dir uint32 + b.RunParallel(func(pb *testing.PB) { + iter := p.newIter() + defer iter.Release() + if atomic.AddUint32(&dir, 1)%2 == 0 { + for pb.Next() && iter.Next() { + } + } else { + if pb.Next() && iter.Last() { + for pb.Next() && iter.Prev() { + } + } + } + }) +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/cache/bench_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/cache/bench_test.go new file mode 100644 index 0000000..89aef69 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/cache/bench_test.go @@ -0,0 +1,29 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package cache + +import ( + "math/rand" + "testing" + "time" +) + +func BenchmarkLRUCache(b *testing.B) { + c := NewCache(NewLRU(10000)) + + b.SetParallelism(10) + b.RunParallel(func(pb *testing.PB) { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + + for pb.Next() { + key := uint64(r.Intn(1000000)) + c.Get(0, key, func() (int, Value) { + return 1, key + }).Release() + } + }) +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/cache/cache.go b/vendor/github.com/syndtr/goleveldb/leveldb/cache/cache.go new file mode 100644 index 0000000..c36ad32 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/cache/cache.go @@ -0,0 +1,704 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Package cache provides interface and implementation of a cache algorithms. 
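The vendored cache package above is exercised by BenchmarkLRUCache through a read-through pattern: Get either finds a live node or fills one by calling setFunc, and every returned Handle must be released. A minimal standalone sketch of that usage (the keys and values are invented for illustration; this program is not part of the patch):

package main

import (
	"fmt"

	"github.com/syndtr/goleveldb/leveldb/cache"
)

func main() {
	// A cache map backed by an LRU cacher with a total charge capacity of 1000.
	c := cache.NewCache(cache.NewLRU(1000))

	// On a miss, Get fills the node by calling setFunc, which returns the
	// node's charge and its value; on a hit, setFunc is not called.
	h := c.Get(0, 42, func() (size int, value cache.Value) {
		return 1, "expensive value"
	})
	if h != nil {
		fmt.Println(h.Value()) // "expensive value"
		h.Release()            // every returned handle must be released
	}
}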
+package cache + +import ( + "sync" + "sync/atomic" + "unsafe" + + "github.com/syndtr/goleveldb/leveldb/util" +) + +// Cacher provides interface to implements a caching functionality. +// An implementation must be safe for concurrent use. +type Cacher interface { + // Capacity returns cache capacity. + Capacity() int + + // SetCapacity sets cache capacity. + SetCapacity(capacity int) + + // Promote promotes the 'cache node'. + Promote(n *Node) + + // Ban evicts the 'cache node' and prevent subsequent 'promote'. + Ban(n *Node) + + // Evict evicts the 'cache node'. + Evict(n *Node) + + // EvictNS evicts 'cache node' with the given namespace. + EvictNS(ns uint64) + + // EvictAll evicts all 'cache node'. + EvictAll() + + // Close closes the 'cache tree' + Close() error +} + +// Value is a 'cacheable object'. It may implements util.Releaser, if +// so the the Release method will be called once object is released. +type Value interface{} + +// NamespaceGetter provides convenient wrapper for namespace. +type NamespaceGetter struct { + Cache *Cache + NS uint64 +} + +// Get simply calls Cache.Get() method. +func (g *NamespaceGetter) Get(key uint64, setFunc func() (size int, value Value)) *Handle { + return g.Cache.Get(g.NS, key, setFunc) +} + +// The hash tables implementation is based on: +// "Dynamic-Sized Nonblocking Hash Tables", by Yujie Liu, +// Kunlong Zhang, and Michael Spear. +// ACM Symposium on Principles of Distributed Computing, Jul 2014. + +const ( + mInitialSize = 1 << 4 + mOverflowThreshold = 1 << 5 + mOverflowGrowThreshold = 1 << 7 +) + +type mBucket struct { + mu sync.Mutex + node []*Node + frozen bool +} + +func (b *mBucket) freeze() []*Node { + b.mu.Lock() + defer b.mu.Unlock() + if !b.frozen { + b.frozen = true + } + return b.node +} + +func (b *mBucket) get(r *Cache, h *mNode, hash uint32, ns, key uint64, noset bool) (done, added bool, n *Node) { + b.mu.Lock() + + if b.frozen { + b.mu.Unlock() + return + } + + // Scan the node. + for _, n := range b.node { + if n.hash == hash && n.ns == ns && n.key == key { + atomic.AddInt32(&n.ref, 1) + b.mu.Unlock() + return true, false, n + } + } + + // Get only. + if noset { + b.mu.Unlock() + return true, false, nil + } + + // Create node. + n = &Node{ + r: r, + hash: hash, + ns: ns, + key: key, + ref: 1, + } + // Add node to bucket. + b.node = append(b.node, n) + bLen := len(b.node) + b.mu.Unlock() + + // Update counter. + grow := atomic.AddInt32(&r.nodes, 1) >= h.growThreshold + if bLen > mOverflowThreshold { + grow = grow || atomic.AddInt32(&h.overflow, 1) >= mOverflowGrowThreshold + } + + // Grow. + if grow && atomic.CompareAndSwapInt32(&h.resizeInProgess, 0, 1) { + nhLen := len(h.buckets) << 1 + nh := &mNode{ + buckets: make([]unsafe.Pointer, nhLen), + mask: uint32(nhLen) - 1, + pred: unsafe.Pointer(h), + growThreshold: int32(nhLen * mOverflowThreshold), + shrinkThreshold: int32(nhLen >> 1), + } + ok := atomic.CompareAndSwapPointer(&r.mHead, unsafe.Pointer(h), unsafe.Pointer(nh)) + if !ok { + panic("BUG: failed swapping head") + } + go nh.initBuckets() + } + + return true, true, n +} + +func (b *mBucket) delete(r *Cache, h *mNode, hash uint32, ns, key uint64) (done, deleted bool) { + b.mu.Lock() + + if b.frozen { + b.mu.Unlock() + return + } + + // Scan the node. + var ( + n *Node + bLen int + ) + for i := range b.node { + n = b.node[i] + if n.ns == ns && n.key == key { + if atomic.LoadInt32(&n.ref) == 0 { + deleted = true + + // Call releaser. 
+ if n.value != nil { + if r, ok := n.value.(util.Releaser); ok { + r.Release() + } + n.value = nil + } + + // Remove node from bucket. + b.node = append(b.node[:i], b.node[i+1:]...) + bLen = len(b.node) + } + break + } + } + b.mu.Unlock() + + if deleted { + // Call OnDel. + for _, f := range n.onDel { + f() + } + + // Update counter. + atomic.AddInt32(&r.size, int32(n.size)*-1) + shrink := atomic.AddInt32(&r.nodes, -1) < h.shrinkThreshold + if bLen >= mOverflowThreshold { + atomic.AddInt32(&h.overflow, -1) + } + + // Shrink. + if shrink && len(h.buckets) > mInitialSize && atomic.CompareAndSwapInt32(&h.resizeInProgess, 0, 1) { + nhLen := len(h.buckets) >> 1 + nh := &mNode{ + buckets: make([]unsafe.Pointer, nhLen), + mask: uint32(nhLen) - 1, + pred: unsafe.Pointer(h), + growThreshold: int32(nhLen * mOverflowThreshold), + shrinkThreshold: int32(nhLen >> 1), + } + ok := atomic.CompareAndSwapPointer(&r.mHead, unsafe.Pointer(h), unsafe.Pointer(nh)) + if !ok { + panic("BUG: failed swapping head") + } + go nh.initBuckets() + } + } + + return true, deleted +} + +type mNode struct { + buckets []unsafe.Pointer // []*mBucket + mask uint32 + pred unsafe.Pointer // *mNode + resizeInProgess int32 + + overflow int32 + growThreshold int32 + shrinkThreshold int32 +} + +func (n *mNode) initBucket(i uint32) *mBucket { + if b := (*mBucket)(atomic.LoadPointer(&n.buckets[i])); b != nil { + return b + } + + p := (*mNode)(atomic.LoadPointer(&n.pred)) + if p != nil { + var node []*Node + if n.mask > p.mask { + // Grow. + pb := (*mBucket)(atomic.LoadPointer(&p.buckets[i&p.mask])) + if pb == nil { + pb = p.initBucket(i & p.mask) + } + m := pb.freeze() + // Split nodes. + for _, x := range m { + if x.hash&n.mask == i { + node = append(node, x) + } + } + } else { + // Shrink. + pb0 := (*mBucket)(atomic.LoadPointer(&p.buckets[i])) + if pb0 == nil { + pb0 = p.initBucket(i) + } + pb1 := (*mBucket)(atomic.LoadPointer(&p.buckets[i+uint32(len(n.buckets))])) + if pb1 == nil { + pb1 = p.initBucket(i + uint32(len(n.buckets))) + } + m0 := pb0.freeze() + m1 := pb1.freeze() + // Merge nodes. + node = make([]*Node, 0, len(m0)+len(m1)) + node = append(node, m0...) + node = append(node, m1...) + } + b := &mBucket{node: node} + if atomic.CompareAndSwapPointer(&n.buckets[i], nil, unsafe.Pointer(b)) { + if len(node) > mOverflowThreshold { + atomic.AddInt32(&n.overflow, int32(len(node)-mOverflowThreshold)) + } + return b + } + } + + return (*mBucket)(atomic.LoadPointer(&n.buckets[i])) +} + +func (n *mNode) initBuckets() { + for i := range n.buckets { + n.initBucket(uint32(i)) + } + atomic.StorePointer(&n.pred, nil) +} + +// Cache is a 'cache map'. +type Cache struct { + mu sync.RWMutex + mHead unsafe.Pointer // *mNode + nodes int32 + size int32 + cacher Cacher + closed bool +} + +// NewCache creates a new 'cache map'. The cacher is optional and +// may be nil. 
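A note on the nil-Cacher case that the NewCache comment above allows: with no Cacher the structure behaves as a concurrent, reference-counted (namespace, key) map, and nothing is ever evicted by capacity. A fragment assuming the same imports as the earlier sketch:

c := cache.NewCache(nil) // nil Cacher: no capacity-based eviction
h := c.Get(1, 7, func() (int, cache.Value) { return 3, []byte("abc") })
fmt.Println(c.Nodes(), c.Size()) // 1 3: one node carrying a charge of 3
h.Release()                      // last reference dropped: the node leaves the map
fmt.Println(c.Nodes())           // 0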
+func NewCache(cacher Cacher) *Cache { + h := &mNode{ + buckets: make([]unsafe.Pointer, mInitialSize), + mask: mInitialSize - 1, + growThreshold: int32(mInitialSize * mOverflowThreshold), + shrinkThreshold: 0, + } + for i := range h.buckets { + h.buckets[i] = unsafe.Pointer(&mBucket{}) + } + r := &Cache{ + mHead: unsafe.Pointer(h), + cacher: cacher, + } + return r +} + +func (r *Cache) getBucket(hash uint32) (*mNode, *mBucket) { + h := (*mNode)(atomic.LoadPointer(&r.mHead)) + i := hash & h.mask + b := (*mBucket)(atomic.LoadPointer(&h.buckets[i])) + if b == nil { + b = h.initBucket(i) + } + return h, b +} + +func (r *Cache) delete(n *Node) bool { + for { + h, b := r.getBucket(n.hash) + done, deleted := b.delete(r, h, n.hash, n.ns, n.key) + if done { + return deleted + } + } +} + +// Nodes returns number of 'cache node' in the map. +func (r *Cache) Nodes() int { + return int(atomic.LoadInt32(&r.nodes)) +} + +// Size returns sums of 'cache node' size in the map. +func (r *Cache) Size() int { + return int(atomic.LoadInt32(&r.size)) +} + +// Capacity returns cache capacity. +func (r *Cache) Capacity() int { + if r.cacher == nil { + return 0 + } + return r.cacher.Capacity() +} + +// SetCapacity sets cache capacity. +func (r *Cache) SetCapacity(capacity int) { + if r.cacher != nil { + r.cacher.SetCapacity(capacity) + } +} + +// Get gets 'cache node' with the given namespace and key. +// If cache node is not found and setFunc is not nil, Get will atomically creates +// the 'cache node' by calling setFunc. Otherwise Get will returns nil. +// +// The returned 'cache handle' should be released after use by calling Release +// method. +func (r *Cache) Get(ns, key uint64, setFunc func() (size int, value Value)) *Handle { + r.mu.RLock() + defer r.mu.RUnlock() + if r.closed { + return nil + } + + hash := murmur32(ns, key, 0xf00) + for { + h, b := r.getBucket(hash) + done, _, n := b.get(r, h, hash, ns, key, setFunc == nil) + if done { + if n != nil { + n.mu.Lock() + if n.value == nil { + if setFunc == nil { + n.mu.Unlock() + n.unref() + return nil + } + + n.size, n.value = setFunc() + if n.value == nil { + n.size = 0 + n.mu.Unlock() + n.unref() + return nil + } + atomic.AddInt32(&r.size, int32(n.size)) + } + n.mu.Unlock() + if r.cacher != nil { + r.cacher.Promote(n) + } + return &Handle{unsafe.Pointer(n)} + } + + break + } + } + return nil +} + +// Delete removes and ban 'cache node' with the given namespace and key. +// A banned 'cache node' will never inserted into the 'cache tree'. Ban +// only attributed to the particular 'cache node', so when a 'cache node' +// is recreated it will not be banned. +// +// If onDel is not nil, then it will be executed if such 'cache node' +// doesn't exist or once the 'cache node' is released. +// +// Delete return true is such 'cache node' exist. +func (r *Cache) Delete(ns, key uint64, onDel func()) bool { + r.mu.RLock() + defer r.mu.RUnlock() + if r.closed { + return false + } + + hash := murmur32(ns, key, 0xf00) + for { + h, b := r.getBucket(hash) + done, _, n := b.get(r, h, hash, ns, key, true) + if done { + if n != nil { + if onDel != nil { + n.mu.Lock() + n.onDel = append(n.onDel, onDel) + n.mu.Unlock() + } + if r.cacher != nil { + r.cacher.Ban(n) + } + n.unref() + return true + } + + break + } + } + + if onDel != nil { + onDel() + } + + return false +} + +// Evict evicts 'cache node' with the given namespace and key. This will +// simply call Cacher.Evict. +// +// Evict return true is such 'cache node' exist. 
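The Delete/ban semantics documented in this hunk are subtle: banning detaches the node from the cacher, but an existing handle keeps the value alive, and the onDel callback fires only when the last reference is released. A fragment illustrating that ordering (same imports as the sketch above):

c := cache.NewCache(cache.NewLRU(10))
h := c.Get(0, 1, func() (int, cache.Value) { return 1, "v1" })
c.Delete(0, 1, func() { fmt.Println("onDel: node fully released") })
fmt.Println(h.Value()) // "v1": the live handle is unaffected by Delete
h.Release()            // final reference drops; onDel fires here
fmt.Println(c.Get(0, 1, nil) == nil) // true: the node is gone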
+func (r *Cache) Evict(ns, key uint64) bool { + r.mu.RLock() + defer r.mu.RUnlock() + if r.closed { + return false + } + + hash := murmur32(ns, key, 0xf00) + for { + h, b := r.getBucket(hash) + done, _, n := b.get(r, h, hash, ns, key, true) + if done { + if n != nil { + if r.cacher != nil { + r.cacher.Evict(n) + } + n.unref() + return true + } + + break + } + } + + return false +} + +// EvictNS evicts 'cache node' with the given namespace. This will +// simply call Cacher.EvictNS. +func (r *Cache) EvictNS(ns uint64) { + r.mu.RLock() + defer r.mu.RUnlock() + if r.closed { + return + } + + if r.cacher != nil { + r.cacher.EvictNS(ns) + } +} + +// EvictAll evicts all 'cache node'. This will simply call Cacher.EvictAll. +func (r *Cache) EvictAll() { + r.mu.RLock() + defer r.mu.RUnlock() + if r.closed { + return + } + + if r.cacher != nil { + r.cacher.EvictAll() + } +} + +// Close closes the 'cache map' and forcefully releases all 'cache node'. +func (r *Cache) Close() error { + r.mu.Lock() + if !r.closed { + r.closed = true + + h := (*mNode)(r.mHead) + h.initBuckets() + + for i := range h.buckets { + b := (*mBucket)(h.buckets[i]) + for _, n := range b.node { + // Call releaser. + if n.value != nil { + if r, ok := n.value.(util.Releaser); ok { + r.Release() + } + n.value = nil + } + + // Call OnDel. + for _, f := range n.onDel { + f() + } + n.onDel = nil + } + } + } + r.mu.Unlock() + + // Avoid deadlock. + if r.cacher != nil { + if err := r.cacher.Close(); err != nil { + return err + } + } + return nil +} + +// CloseWeak closes the 'cache map' and evict all 'cache node' from cacher, but +// unlike Close it doesn't forcefully releases 'cache node'. +func (r *Cache) CloseWeak() error { + r.mu.Lock() + if !r.closed { + r.closed = true + } + r.mu.Unlock() + + // Avoid deadlock. + if r.cacher != nil { + r.cacher.EvictAll() + if err := r.cacher.Close(); err != nil { + return err + } + } + return nil +} + +// Node is a 'cache node'. +type Node struct { + r *Cache + + hash uint32 + ns, key uint64 + + mu sync.Mutex + size int + value Value + + ref int32 + onDel []func() + + CacheData unsafe.Pointer +} + +// NS returns this 'cache node' namespace. +func (n *Node) NS() uint64 { + return n.ns +} + +// Key returns this 'cache node' key. +func (n *Node) Key() uint64 { + return n.key +} + +// Size returns this 'cache node' size. +func (n *Node) Size() int { + return n.size +} + +// Value returns this 'cache node' value. +func (n *Node) Value() Value { + return n.value +} + +// Ref returns this 'cache node' ref counter. +func (n *Node) Ref() int32 { + return atomic.LoadInt32(&n.ref) +} + +// GetHandle returns an handle for this 'cache node'. +func (n *Node) GetHandle() *Handle { + if atomic.AddInt32(&n.ref, 1) <= 1 { + panic("BUG: Node.GetHandle on zero ref") + } + return &Handle{unsafe.Pointer(n)} +} + +func (n *Node) unref() { + if atomic.AddInt32(&n.ref, -1) == 0 { + n.r.delete(n) + } +} + +func (n *Node) unrefLocked() { + if atomic.AddInt32(&n.ref, -1) == 0 { + n.r.mu.RLock() + if !n.r.closed { + n.r.delete(n) + } + n.r.mu.RUnlock() + } +} + +// Handle is a 'cache handle' of a 'cache node'. +type Handle struct { + n unsafe.Pointer // *Node +} + +// Value returns the value of the 'cache node'. +func (h *Handle) Value() Value { + n := (*Node)(atomic.LoadPointer(&h.n)) + if n != nil { + return n.value + } + return nil +} + +// Release releases this 'cache handle'. +// It is safe to call release multiple times. 
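Release, defined just below, swaps the handle's node pointer to nil with a compare-and-swap, which is what makes repeated releases harmless. A short illustration (same imports as above):

h := c.Get(0, 9, func() (int, cache.Value) { return 1, "v" })
h.Release()
h.Release()                   // safe no-op: the node pointer is already nil
fmt.Println(h.Value() == nil) // true once the handle has been released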
+func (h *Handle) Release() { + nPtr := atomic.LoadPointer(&h.n) + if nPtr != nil && atomic.CompareAndSwapPointer(&h.n, nPtr, nil) { + n := (*Node)(nPtr) + n.unrefLocked() + } +} + +func murmur32(ns, key uint64, seed uint32) uint32 { + const ( + m = uint32(0x5bd1e995) + r = 24 + ) + + k1 := uint32(ns >> 32) + k2 := uint32(ns) + k3 := uint32(key >> 32) + k4 := uint32(key) + + k1 *= m + k1 ^= k1 >> r + k1 *= m + + k2 *= m + k2 ^= k2 >> r + k2 *= m + + k3 *= m + k3 ^= k3 >> r + k3 *= m + + k4 *= m + k4 ^= k4 >> r + k4 *= m + + h := seed + + h *= m + h ^= k1 + h *= m + h ^= k2 + h *= m + h ^= k3 + h *= m + h ^= k4 + + h ^= h >> 13 + h *= m + h ^= h >> 15 + + return h +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/cache/cache_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/cache/cache_test.go new file mode 100644 index 0000000..6b017bd --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/cache/cache_test.go @@ -0,0 +1,563 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package cache + +import ( + "math/rand" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + "unsafe" +) + +type int32o int32 + +func (o *int32o) acquire() { + if atomic.AddInt32((*int32)(o), 1) != 1 { + panic("BUG: invalid ref") + } +} + +func (o *int32o) Release() { + if atomic.AddInt32((*int32)(o), -1) != 0 { + panic("BUG: invalid ref") + } +} + +type releaserFunc struct { + fn func() + value Value +} + +func (r releaserFunc) Release() { + if r.fn != nil { + r.fn() + } +} + +func set(c *Cache, ns, key uint64, value Value, charge int, relf func()) *Handle { + return c.Get(ns, key, func() (int, Value) { + if relf != nil { + return charge, releaserFunc{relf, value} + } + return charge, value + }) +} + +type cacheMapTestParams struct { + nobjects, nhandles, concurrent, repeat int +} + +func TestCacheMap(t *testing.T) { + runtime.GOMAXPROCS(runtime.NumCPU()) + + var params []cacheMapTestParams + if testing.Short() { + params = []cacheMapTestParams{ + {1000, 100, 20, 3}, + {10000, 300, 50, 10}, + } + } else { + params = []cacheMapTestParams{ + {10000, 400, 50, 3}, + {100000, 1000, 100, 10}, + } + } + + var ( + objects [][]int32o + handles [][]unsafe.Pointer + ) + + for _, x := range params { + objects = append(objects, make([]int32o, x.nobjects)) + handles = append(handles, make([]unsafe.Pointer, x.nhandles)) + } + + c := NewCache(nil) + + wg := new(sync.WaitGroup) + var done int32 + + for ns, x := range params { + for i := 0; i < x.concurrent; i++ { + wg.Add(1) + go func(ns, i, repeat int, objects []int32o, handles []unsafe.Pointer) { + defer wg.Done() + r := rand.New(rand.NewSource(time.Now().UnixNano())) + + for j := len(objects) * repeat; j >= 0; j-- { + key := uint64(r.Intn(len(objects))) + h := c.Get(uint64(ns), key, func() (int, Value) { + o := &objects[key] + o.acquire() + return 1, o + }) + if v := h.Value().(*int32o); v != &objects[key] { + t.Fatalf("#%d invalid value: want=%p got=%p", ns, &objects[key], v) + } + if objects[key] != 1 { + t.Fatalf("#%d invalid object %d: %d", ns, key, objects[key]) + } + if !atomic.CompareAndSwapPointer(&handles[r.Intn(len(handles))], nil, unsafe.Pointer(h)) { + h.Release() + } + } + }(ns, i, x.repeat, objects[ns], handles[ns]) + } + + go func(handles []unsafe.Pointer) { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + + for atomic.LoadInt32(&done) == 0 { + i := r.Intn(len(handles)) + h := 
(*Handle)(atomic.LoadPointer(&handles[i])) + if h != nil && atomic.CompareAndSwapPointer(&handles[i], unsafe.Pointer(h), nil) { + h.Release() + } + time.Sleep(time.Millisecond) + } + }(handles[ns]) + } + + go func() { + handles := make([]*Handle, 100000) + for atomic.LoadInt32(&done) == 0 { + for i := range handles { + handles[i] = c.Get(999999999, uint64(i), func() (int, Value) { + return 1, 1 + }) + } + for _, h := range handles { + h.Release() + } + } + }() + + wg.Wait() + + atomic.StoreInt32(&done, 1) + + for _, handles0 := range handles { + for i := range handles0 { + h := (*Handle)(atomic.LoadPointer(&handles0[i])) + if h != nil && atomic.CompareAndSwapPointer(&handles0[i], unsafe.Pointer(h), nil) { + h.Release() + } + } + } + + for ns, objects0 := range objects { + for i, o := range objects0 { + if o != 0 { + t.Fatalf("invalid object #%d.%d: ref=%d", ns, i, o) + } + } + } +} + +func TestCacheMap_NodesAndSize(t *testing.T) { + c := NewCache(nil) + if c.Nodes() != 0 { + t.Errorf("invalid nodes counter: want=%d got=%d", 0, c.Nodes()) + } + if c.Size() != 0 { + t.Errorf("invalid size counter: want=%d got=%d", 0, c.Size()) + } + set(c, 0, 1, 1, 1, nil) + set(c, 0, 2, 2, 2, nil) + set(c, 1, 1, 3, 3, nil) + set(c, 2, 1, 4, 1, nil) + if c.Nodes() != 4 { + t.Errorf("invalid nodes counter: want=%d got=%d", 4, c.Nodes()) + } + if c.Size() != 7 { + t.Errorf("invalid size counter: want=%d got=%d", 4, c.Size()) + } +} + +func TestLRUCache_Capacity(t *testing.T) { + c := NewCache(NewLRU(10)) + if c.Capacity() != 10 { + t.Errorf("invalid capacity: want=%d got=%d", 10, c.Capacity()) + } + set(c, 0, 1, 1, 1, nil).Release() + set(c, 0, 2, 2, 2, nil).Release() + set(c, 1, 1, 3, 3, nil).Release() + set(c, 2, 1, 4, 1, nil).Release() + set(c, 2, 2, 5, 1, nil).Release() + set(c, 2, 3, 6, 1, nil).Release() + set(c, 2, 4, 7, 1, nil).Release() + set(c, 2, 5, 8, 1, nil).Release() + if c.Nodes() != 7 { + t.Errorf("invalid nodes counter: want=%d got=%d", 7, c.Nodes()) + } + if c.Size() != 10 { + t.Errorf("invalid size counter: want=%d got=%d", 10, c.Size()) + } + c.SetCapacity(9) + if c.Capacity() != 9 { + t.Errorf("invalid capacity: want=%d got=%d", 9, c.Capacity()) + } + if c.Nodes() != 6 { + t.Errorf("invalid nodes counter: want=%d got=%d", 6, c.Nodes()) + } + if c.Size() != 8 { + t.Errorf("invalid size counter: want=%d got=%d", 8, c.Size()) + } +} + +func TestCacheMap_NilValue(t *testing.T) { + c := NewCache(NewLRU(10)) + h := c.Get(0, 0, func() (size int, value Value) { + return 1, nil + }) + if h != nil { + t.Error("cache handle is non-nil") + } + if c.Nodes() != 0 { + t.Errorf("invalid nodes counter: want=%d got=%d", 0, c.Nodes()) + } + if c.Size() != 0 { + t.Errorf("invalid size counter: want=%d got=%d", 0, c.Size()) + } +} + +func TestLRUCache_GetLatency(t *testing.T) { + runtime.GOMAXPROCS(runtime.NumCPU()) + + const ( + concurrentSet = 30 + concurrentGet = 3 + duration = 3 * time.Second + delay = 3 * time.Millisecond + maxkey = 100000 + ) + + var ( + set, getHit, getAll int32 + getMaxLatency, getDuration int64 + ) + + c := NewCache(NewLRU(5000)) + wg := &sync.WaitGroup{} + until := time.Now().Add(duration) + for i := 0; i < concurrentSet; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + r := rand.New(rand.NewSource(time.Now().UnixNano())) + for time.Now().Before(until) { + c.Get(0, uint64(r.Intn(maxkey)), func() (int, Value) { + time.Sleep(delay) + atomic.AddInt32(&set, 1) + return 1, 1 + }).Release() + } + }(i) + } + for i := 0; i < concurrentGet; i++ { + wg.Add(1) + go func(i int) { + defer 
wg.Done() + r := rand.New(rand.NewSource(time.Now().UnixNano())) + for { + mark := time.Now() + if mark.Before(until) { + h := c.Get(0, uint64(r.Intn(maxkey)), nil) + latency := int64(time.Now().Sub(mark)) + m := atomic.LoadInt64(&getMaxLatency) + if latency > m { + atomic.CompareAndSwapInt64(&getMaxLatency, m, latency) + } + atomic.AddInt64(&getDuration, latency) + if h != nil { + atomic.AddInt32(&getHit, 1) + h.Release() + } + atomic.AddInt32(&getAll, 1) + } else { + break + } + } + }(i) + } + + wg.Wait() + getAvglatency := time.Duration(getDuration) / time.Duration(getAll) + t.Logf("set=%d getHit=%d getAll=%d getMaxLatency=%v getAvgLatency=%v", + set, getHit, getAll, time.Duration(getMaxLatency), getAvglatency) + + if getAvglatency > delay/3 { + t.Errorf("get avg latency > %v: got=%v", delay/3, getAvglatency) + } +} + +func TestLRUCache_HitMiss(t *testing.T) { + cases := []struct { + key uint64 + value string + }{ + {1, "vvvvvvvvv"}, + {100, "v1"}, + {0, "v2"}, + {12346, "v3"}, + {777, "v4"}, + {999, "v5"}, + {7654, "v6"}, + {2, "v7"}, + {3, "v8"}, + {9, "v9"}, + } + + setfin := 0 + c := NewCache(NewLRU(1000)) + for i, x := range cases { + set(c, 0, x.key, x.value, len(x.value), func() { + setfin++ + }).Release() + for j, y := range cases { + h := c.Get(0, y.key, nil) + if j <= i { + // should hit + if h == nil { + t.Errorf("case '%d' iteration '%d' is miss", i, j) + } else { + if x := h.Value().(releaserFunc).value.(string); x != y.value { + t.Errorf("case '%d' iteration '%d' has invalid value got '%s', want '%s'", i, j, x, y.value) + } + } + } else { + // should miss + if h != nil { + t.Errorf("case '%d' iteration '%d' is hit , value '%s'", i, j, h.Value().(releaserFunc).value.(string)) + } + } + if h != nil { + h.Release() + } + } + } + + for i, x := range cases { + finalizerOk := false + c.Delete(0, x.key, func() { + finalizerOk = true + }) + + if !finalizerOk { + t.Errorf("case %d delete finalizer not executed", i) + } + + for j, y := range cases { + h := c.Get(0, y.key, nil) + if j > i { + // should hit + if h == nil { + t.Errorf("case '%d' iteration '%d' is miss", i, j) + } else { + if x := h.Value().(releaserFunc).value.(string); x != y.value { + t.Errorf("case '%d' iteration '%d' has invalid value got '%s', want '%s'", i, j, x, y.value) + } + } + } else { + // should miss + if h != nil { + t.Errorf("case '%d' iteration '%d' is hit, value '%s'", i, j, h.Value().(releaserFunc).value.(string)) + } + } + if h != nil { + h.Release() + } + } + } + + if setfin != len(cases) { + t.Errorf("some set finalizer may not be executed, want=%d got=%d", len(cases), setfin) + } +} + +func TestLRUCache_Eviction(t *testing.T) { + c := NewCache(NewLRU(12)) + o1 := set(c, 0, 1, 1, 1, nil) + set(c, 0, 2, 2, 1, nil).Release() + set(c, 0, 3, 3, 1, nil).Release() + set(c, 0, 4, 4, 1, nil).Release() + set(c, 0, 5, 5, 1, nil).Release() + if h := c.Get(0, 2, nil); h != nil { // 1,3,4,5,2 + h.Release() + } + set(c, 0, 9, 9, 10, nil).Release() // 5,2,9 + + for _, key := range []uint64{9, 2, 5, 1} { + h := c.Get(0, key, nil) + if h == nil { + t.Errorf("miss for key '%d'", key) + } else { + if x := h.Value().(int); x != int(key) { + t.Errorf("invalid value for key '%d' want '%d', got '%d'", key, key, x) + } + h.Release() + } + } + o1.Release() + for _, key := range []uint64{1, 2, 5} { + h := c.Get(0, key, nil) + if h == nil { + t.Errorf("miss for key '%d'", key) + } else { + if x := h.Value().(int); x != int(key) { + t.Errorf("invalid value for key '%d' want '%d', got '%d'", key, key, x) + } + h.Release() + } 
+ } + for _, key := range []uint64{3, 4, 9} { + h := c.Get(0, key, nil) + if h != nil { + t.Errorf("hit for key '%d'", key) + if x := h.Value().(int); x != int(key) { + t.Errorf("invalid value for key '%d' want '%d', got '%d'", key, key, x) + } + h.Release() + } + } +} + +func TestLRUCache_Evict(t *testing.T) { + c := NewCache(NewLRU(6)) + set(c, 0, 1, 1, 1, nil).Release() + set(c, 0, 2, 2, 1, nil).Release() + set(c, 1, 1, 4, 1, nil).Release() + set(c, 1, 2, 5, 1, nil).Release() + set(c, 2, 1, 6, 1, nil).Release() + set(c, 2, 2, 7, 1, nil).Release() + + for ns := 0; ns < 3; ns++ { + for key := 1; key < 3; key++ { + if h := c.Get(uint64(ns), uint64(key), nil); h != nil { + h.Release() + } else { + t.Errorf("Cache.Get on #%d.%d return nil", ns, key) + } + } + } + + if ok := c.Evict(0, 1); !ok { + t.Error("first Cache.Evict on #0.1 return false") + } + if ok := c.Evict(0, 1); ok { + t.Error("second Cache.Evict on #0.1 return true") + } + if h := c.Get(0, 1, nil); h != nil { + t.Errorf("Cache.Get on #0.1 return non-nil: %v", h.Value()) + } + + c.EvictNS(1) + if h := c.Get(1, 1, nil); h != nil { + t.Errorf("Cache.Get on #1.1 return non-nil: %v", h.Value()) + } + if h := c.Get(1, 2, nil); h != nil { + t.Errorf("Cache.Get on #1.2 return non-nil: %v", h.Value()) + } + + c.EvictAll() + for ns := 0; ns < 3; ns++ { + for key := 1; key < 3; key++ { + if h := c.Get(uint64(ns), uint64(key), nil); h != nil { + t.Errorf("Cache.Get on #%d.%d return non-nil: %v", ns, key, h.Value()) + } + } + } +} + +func TestLRUCache_Delete(t *testing.T) { + delFuncCalled := 0 + delFunc := func() { + delFuncCalled++ + } + + c := NewCache(NewLRU(2)) + set(c, 0, 1, 1, 1, nil).Release() + set(c, 0, 2, 2, 1, nil).Release() + + if ok := c.Delete(0, 1, delFunc); !ok { + t.Error("Cache.Delete on #1 return false") + } + if h := c.Get(0, 1, nil); h != nil { + t.Errorf("Cache.Get on #1 return non-nil: %v", h.Value()) + } + if ok := c.Delete(0, 1, delFunc); ok { + t.Error("Cache.Delete on #1 return true") + } + + h2 := c.Get(0, 2, nil) + if h2 == nil { + t.Error("Cache.Get on #2 return nil") + } + if ok := c.Delete(0, 2, delFunc); !ok { + t.Error("(1) Cache.Delete on #2 return false") + } + if ok := c.Delete(0, 2, delFunc); !ok { + t.Error("(2) Cache.Delete on #2 return false") + } + + set(c, 0, 3, 3, 1, nil).Release() + set(c, 0, 4, 4, 1, nil).Release() + c.Get(0, 2, nil).Release() + + for key := 2; key <= 4; key++ { + if h := c.Get(0, uint64(key), nil); h != nil { + h.Release() + } else { + t.Errorf("Cache.Get on #%d return nil", key) + } + } + + h2.Release() + if h := c.Get(0, 2, nil); h != nil { + t.Errorf("Cache.Get on #2 return non-nil: %v", h.Value()) + } + + if delFuncCalled != 4 { + t.Errorf("delFunc isn't called 4 times: got=%d", delFuncCalled) + } +} + +func TestLRUCache_Close(t *testing.T) { + relFuncCalled := 0 + relFunc := func() { + relFuncCalled++ + } + delFuncCalled := 0 + delFunc := func() { + delFuncCalled++ + } + + c := NewCache(NewLRU(2)) + set(c, 0, 1, 1, 1, relFunc).Release() + set(c, 0, 2, 2, 1, relFunc).Release() + + h3 := set(c, 0, 3, 3, 1, relFunc) + if h3 == nil { + t.Error("Cache.Get on #3 return nil") + } + if ok := c.Delete(0, 3, delFunc); !ok { + t.Error("Cache.Delete on #3 return false") + } + + c.Close() + + if relFuncCalled != 3 { + t.Errorf("relFunc isn't called 3 times: got=%d", relFuncCalled) + } + if delFuncCalled != 1 { + t.Errorf("delFunc isn't called 1 times: got=%d", delFuncCalled) + } +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/cache/lru.go 
b/vendor/github.com/syndtr/goleveldb/leveldb/cache/lru.go new file mode 100644 index 0000000..d9a84cd --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/cache/lru.go @@ -0,0 +1,195 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package cache + +import ( + "sync" + "unsafe" +) + +type lruNode struct { + n *Node + h *Handle + ban bool + + next, prev *lruNode +} + +func (n *lruNode) insert(at *lruNode) { + x := at.next + at.next = n + n.prev = at + n.next = x + x.prev = n +} + +func (n *lruNode) remove() { + if n.prev != nil { + n.prev.next = n.next + n.next.prev = n.prev + n.prev = nil + n.next = nil + } else { + panic("BUG: removing removed node") + } +} + +type lru struct { + mu sync.Mutex + capacity int + used int + recent lruNode +} + +func (r *lru) reset() { + r.recent.next = &r.recent + r.recent.prev = &r.recent + r.used = 0 +} + +func (r *lru) Capacity() int { + r.mu.Lock() + defer r.mu.Unlock() + return r.capacity +} + +func (r *lru) SetCapacity(capacity int) { + var evicted []*lruNode + + r.mu.Lock() + r.capacity = capacity + for r.used > r.capacity { + rn := r.recent.prev + if rn == nil { + panic("BUG: invalid LRU used or capacity counter") + } + rn.remove() + rn.n.CacheData = nil + r.used -= rn.n.Size() + evicted = append(evicted, rn) + } + r.mu.Unlock() + + for _, rn := range evicted { + rn.h.Release() + } +} + +func (r *lru) Promote(n *Node) { + var evicted []*lruNode + + r.mu.Lock() + if n.CacheData == nil { + if n.Size() <= r.capacity { + rn := &lruNode{n: n, h: n.GetHandle()} + rn.insert(&r.recent) + n.CacheData = unsafe.Pointer(rn) + r.used += n.Size() + + for r.used > r.capacity { + rn := r.recent.prev + if rn == nil { + panic("BUG: invalid LRU used or capacity counter") + } + rn.remove() + rn.n.CacheData = nil + r.used -= rn.n.Size() + evicted = append(evicted, rn) + } + } + } else { + rn := (*lruNode)(n.CacheData) + if !rn.ban { + rn.remove() + rn.insert(&r.recent) + } + } + r.mu.Unlock() + + for _, rn := range evicted { + rn.h.Release() + } +} + +func (r *lru) Ban(n *Node) { + r.mu.Lock() + if n.CacheData == nil { + n.CacheData = unsafe.Pointer(&lruNode{n: n, ban: true}) + } else { + rn := (*lruNode)(n.CacheData) + if !rn.ban { + rn.remove() + rn.ban = true + r.used -= rn.n.Size() + r.mu.Unlock() + + rn.h.Release() + rn.h = nil + return + } + } + r.mu.Unlock() +} + +func (r *lru) Evict(n *Node) { + r.mu.Lock() + rn := (*lruNode)(n.CacheData) + if rn == nil || rn.ban { + r.mu.Unlock() + return + } + n.CacheData = nil + r.mu.Unlock() + + rn.h.Release() +} + +func (r *lru) EvictNS(ns uint64) { + var evicted []*lruNode + + r.mu.Lock() + for e := r.recent.prev; e != &r.recent; { + rn := e + e = e.prev + if rn.n.NS() == ns { + rn.remove() + rn.n.CacheData = nil + r.used -= rn.n.Size() + evicted = append(evicted, rn) + } + } + r.mu.Unlock() + + for _, rn := range evicted { + rn.h.Release() + } +} + +func (r *lru) EvictAll() { + r.mu.Lock() + back := r.recent.prev + for rn := back; rn != &r.recent; rn = rn.prev { + rn.n.CacheData = nil + } + r.reset() + r.mu.Unlock() + + for rn := back; rn != &r.recent; rn = rn.prev { + rn.h.Release() + } +} + +func (r *lru) Close() error { + return nil +} + +// NewLRU create a new LRU-cache. 
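Capacity in this LRU is measured in the charge units passed to Get, not in node counts, and eviction walks from the least recently used end. A sketch against NewLRU, defined just below (same imports as the first sketch):

c := cache.NewCache(cache.NewLRU(2)) // room for a total charge of 2
for k := uint64(1); k <= 3; k++ {
	k := k // capture the loop variable for the closure
	c.Get(0, k, func() (int, cache.Value) { return 1, k }).Release()
}
// Inserting key 3 pushed the total charge to 3, so the least recently
// used node (key 1) was evicted.
fmt.Println(c.Get(0, 1, nil) == nil) // true
fmt.Println(c.Nodes())               // 2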
+func NewLRU(capacity int) Cacher {
+	r := &lru{capacity: capacity}
+	r.reset()
+	return r
+}
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/comparer.go b/vendor/github.com/syndtr/goleveldb/leveldb/comparer.go
new file mode 100644
index 0000000..448402b
--- /dev/null
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/comparer.go
@@ -0,0 +1,67 @@
+// Copyright (c) 2012, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package leveldb
+
+import (
+	"github.com/syndtr/goleveldb/leveldb/comparer"
+)
+
+type iComparer struct {
+	ucmp comparer.Comparer
+}
+
+func (icmp *iComparer) uName() string {
+	return icmp.ucmp.Name()
+}
+
+func (icmp *iComparer) uCompare(a, b []byte) int {
+	return icmp.ucmp.Compare(a, b)
+}
+
+func (icmp *iComparer) uSeparator(dst, a, b []byte) []byte {
+	return icmp.ucmp.Separator(dst, a, b)
+}
+
+func (icmp *iComparer) uSuccessor(dst, b []byte) []byte {
+	return icmp.ucmp.Successor(dst, b)
+}
+
+func (icmp *iComparer) Name() string {
+	return icmp.uName()
+}
+
+func (icmp *iComparer) Compare(a, b []byte) int {
+	x := icmp.uCompare(internalKey(a).ukey(), internalKey(b).ukey())
+	if x == 0 {
+		if m, n := internalKey(a).num(), internalKey(b).num(); m > n {
+			return -1
+		} else if m < n {
+			return 1
+		}
+	}
+	return x
+}
+
+func (icmp *iComparer) Separator(dst, a, b []byte) []byte {
+	ua, ub := internalKey(a).ukey(), internalKey(b).ukey()
+	dst = icmp.uSeparator(dst, ua, ub)
+	if dst != nil && len(dst) < len(ua) && icmp.uCompare(ua, dst) < 0 {
+		// Append earliest possible number.
+		return append(dst, keyMaxNumBytes...)
+	}
+	return nil
+}
+
+func (icmp *iComparer) Successor(dst, b []byte) []byte {
+	ub := internalKey(b).ukey()
+	dst = icmp.uSuccessor(dst, ub)
+	if dst != nil && len(dst) < len(ub) && icmp.uCompare(ub, dst) < 0 {
+		// Append earliest possible number.
+		return append(dst, keyMaxNumBytes...)
+	}
+	return nil
+}
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/comparer/bytes_comparer.go b/vendor/github.com/syndtr/goleveldb/leveldb/comparer/bytes_comparer.go
new file mode 100644
index 0000000..abf9fb6
--- /dev/null
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/comparer/bytes_comparer.go
@@ -0,0 +1,51 @@
+// Copyright (c) 2012, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package comparer
+
+import "bytes"
+
+type bytesComparer struct{}
+
+func (bytesComparer) Compare(a, b []byte) int {
+	return bytes.Compare(a, b)
+}
+
+func (bytesComparer) Name() string {
+	return "leveldb.BytewiseComparator"
+}
+
+func (bytesComparer) Separator(dst, a, b []byte) []byte {
+	i, n := 0, len(a)
+	if n > len(b) {
+		n = len(b)
+	}
+	for ; i < n && a[i] == b[i]; i++ {
+	}
+	if i >= n {
+		// Do not shorten if one string is a prefix of the other
+	} else if c := a[i]; c < 0xff && c+1 < b[i] {
+		dst = append(dst, a[:i+1]...)
+		dst[len(dst)-1]++
+		return dst
+	}
+	return nil
+}
+
+func (bytesComparer) Successor(dst, b []byte) []byte {
+	for i, c := range b {
+		if c != 0xff {
+			dst = append(dst, b[:i+1]...)
+			dst[len(dst)-1]++
+			return dst
+		}
+	}
+	return nil
+}
+
+// DefaultComparer is the default implementation of the Comparer interface.
+// It uses the natural ordering, consistent with bytes.Compare.
+var DefaultComparer = bytesComparer{}
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/comparer/comparer.go b/vendor/github.com/syndtr/goleveldb/leveldb/comparer/comparer.go
new file mode 100644
index 0000000..2c522db
--- /dev/null
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/comparer/comparer.go
@@ -0,0 +1,57 @@
+// Copyright (c) 2012, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// Package comparer provides interfaces and implementations for ordering
+// sets of data.
+package comparer
+
+// BasicComparer is the interface that wraps the basic Compare method.
+type BasicComparer interface {
+	// Compare returns -1, 0, or +1 depending on whether a is 'less than',
+	// 'equal to' or 'greater than' b. The two arguments can only be 'equal'
+	// if their contents are exactly equal. Furthermore, the empty slice
+	// must be 'less than' any non-empty slice.
+	Compare(a, b []byte) int
+}
+
+// Comparer defines a total ordering over the space of []byte keys: a 'less
+// than' relationship.
+type Comparer interface {
+	BasicComparer
+
+	// Name returns the name of the comparer.
+	//
+	// The Level-DB on-disk format stores the comparer name, and opening a
+	// database with a different comparer from the one it was created with
+	// will result in an error.
+	//
+	// An implementation should switch to a new name whenever the comparer
+	// implementation changes in a way that will cause the relative ordering
+	// of any two keys to change.
+	//
+	// Names starting with "leveldb." are reserved and should not be used
+	// by any users of this package.
+	Name() string
+
+	// Below are advanced functions used to reduce the space requirements
+	// for internal data structures such as index blocks.
+
+	// Separator appends a sequence of bytes x to dst such that a <= x && x < b,
+	// where 'less than' is consistent with Compare. An implementation should
+	// return nil if x is equal to a.
+	//
+	// The contents of a and b must not be modified in any way. Doing so
+	// may corrupt the internal state.
+	Separator(dst, a, b []byte) []byte
+
+	// Successor appends a sequence of bytes x to dst such that x >= b, where
+	// 'less than' is consistent with Compare. An implementation should return
+	// nil if x is equal to b.
+	//
+	// The contents of b must not be modified in any way. Doing so may
+	// corrupt the internal state.
+	Successor(dst, b []byte) []byte
+}
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/corrupt_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/corrupt_test.go
new file mode 100644
index 0000000..a987b27
--- /dev/null
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/corrupt_test.go
@@ -0,0 +1,498 @@
+// Copyright (c) 2013, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
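To make the Separator/Successor contract above concrete, a fragment using DefaultComparer (it imports "fmt" and github.com/syndtr/goleveldb/leveldb/comparer; the outputs follow from the bytesComparer code in this hunk):

cmp := comparer.DefaultComparer
// Shortest separator: keep the shared prefix "ab", bump 'c' to 'd';
// "abd" satisfies "abcde" <= "abd" && "abd" < "abxyz".
fmt.Printf("%q\n", cmp.Separator(nil, []byte("abcde"), []byte("abxyz"))) // "abd"
// Shortest successor: the first non-0xff byte is incremented.
fmt.Printf("%q\n", cmp.Successor(nil, []byte("abc"))) // "b"
// No shortening is possible when one key is a prefix of the other.
fmt.Println(cmp.Separator(nil, []byte("ab"), []byte("abc")) == nil) // true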
+ +package leveldb + +import ( + "bytes" + "fmt" + "io" + "math/rand" + "testing" + "time" + + "github.com/syndtr/goleveldb/leveldb/filter" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" +) + +const ctValSize = 1000 + +type dbCorruptHarness struct { + dbHarness +} + +func newDbCorruptHarnessWopt(t *testing.T, o *opt.Options) *dbCorruptHarness { + h := new(dbCorruptHarness) + h.init(t, o) + return h +} + +func newDbCorruptHarness(t *testing.T) *dbCorruptHarness { + return newDbCorruptHarnessWopt(t, &opt.Options{ + BlockCacheCapacity: 100, + Strict: opt.StrictJournalChecksum, + }) +} + +func (h *dbCorruptHarness) recover() { + p := &h.dbHarness + t := p.t + + var err error + p.db, err = Recover(h.stor, h.o) + if err != nil { + t.Fatal("Repair: got error: ", err) + } +} + +func (h *dbCorruptHarness) build(n int) { + p := &h.dbHarness + t := p.t + db := p.db + + batch := new(Batch) + for i := 0; i < n; i++ { + batch.Reset() + batch.Put(tkey(i), tval(i, ctValSize)) + err := db.Write(batch, p.wo) + if err != nil { + t.Fatal("write error: ", err) + } + } +} + +func (h *dbCorruptHarness) buildShuffled(n int, rnd *rand.Rand) { + p := &h.dbHarness + t := p.t + db := p.db + + batch := new(Batch) + for i := range rnd.Perm(n) { + batch.Reset() + batch.Put(tkey(i), tval(i, ctValSize)) + err := db.Write(batch, p.wo) + if err != nil { + t.Fatal("write error: ", err) + } + } +} + +func (h *dbCorruptHarness) deleteRand(n, max int, rnd *rand.Rand) { + p := &h.dbHarness + t := p.t + db := p.db + + batch := new(Batch) + for i := 0; i < n; i++ { + batch.Reset() + batch.Delete(tkey(rnd.Intn(max))) + err := db.Write(batch, p.wo) + if err != nil { + t.Fatal("write error: ", err) + } + } +} + +func (h *dbCorruptHarness) corrupt(ft storage.FileType, fi, offset, n int) { + p := &h.dbHarness + t := p.t + + fds, _ := p.stor.List(ft) + sortFds(fds) + if fi < 0 { + fi = len(fds) - 1 + } + if fi >= len(fds) { + t.Fatalf("no such file with type %q with index %d", ft, fi) + } + + fd := fds[fi] + r, err := h.stor.Open(fd) + if err != nil { + t.Fatal("cannot open file: ", err) + } + x, err := r.Seek(0, 2) + if err != nil { + t.Fatal("cannot query file size: ", err) + } + m := int(x) + if _, err := r.Seek(0, 0); err != nil { + t.Fatal(err) + } + + if offset < 0 { + if -offset > m { + offset = 0 + } else { + offset = m + offset + } + } + if offset > m { + offset = m + } + if offset+n > m { + n = m - offset + } + + buf := make([]byte, m) + _, err = io.ReadFull(r, buf) + if err != nil { + t.Fatal("cannot read file: ", err) + } + r.Close() + + for i := 0; i < n; i++ { + buf[offset+i] ^= 0x80 + } + + err = h.stor.Remove(fd) + if err != nil { + t.Fatal("cannot remove old file: ", err) + } + w, err := h.stor.Create(fd) + if err != nil { + t.Fatal("cannot create new file: ", err) + } + _, err = w.Write(buf) + if err != nil { + t.Fatal("cannot write new file: ", err) + } + w.Close() +} + +func (h *dbCorruptHarness) removeAll(ft storage.FileType) { + fds, err := h.stor.List(ft) + if err != nil { + h.t.Fatal("get files: ", err) + } + for _, fd := range fds { + if err := h.stor.Remove(fd); err != nil { + h.t.Error("remove file: ", err) + } + } +} + +func (h *dbCorruptHarness) forceRemoveAll(ft storage.FileType) { + fds, err := h.stor.List(ft) + if err != nil { + h.t.Fatal("get files: ", err) + } + for _, fd := range fds { + if err := h.stor.ForceRemove(fd); err != nil { + h.t.Error("remove file: ", err) + } + } +} + +func (h *dbCorruptHarness) removeOne(ft storage.FileType) { + fds, err := 
h.stor.List(ft) + if err != nil { + h.t.Fatal("get files: ", err) + } + fd := fds[rand.Intn(len(fds))] + h.t.Logf("removing file @%d", fd.Num) + if err := h.stor.Remove(fd); err != nil { + h.t.Error("remove file: ", err) + } +} + +func (h *dbCorruptHarness) check(min, max int) { + p := &h.dbHarness + t := p.t + db := p.db + + var n, badk, badv, missed, good int + iter := db.NewIterator(nil, p.ro) + for iter.Next() { + k := 0 + fmt.Sscanf(string(iter.Key()), "%d", &k) + if k < n { + badk++ + continue + } + missed += k - n + n = k + 1 + if !bytes.Equal(iter.Value(), tval(k, ctValSize)) { + badv++ + } else { + good++ + } + } + err := iter.Error() + iter.Release() + t.Logf("want=%d..%d got=%d badkeys=%d badvalues=%d missed=%d, err=%v", + min, max, good, badk, badv, missed, err) + if good < min || good > max { + t.Errorf("good entries number not in range") + } +} + +func TestCorruptDB_Journal(t *testing.T) { + h := newDbCorruptHarness(t) + defer h.close() + + h.build(100) + h.check(100, 100) + h.closeDB() + h.corrupt(storage.TypeJournal, -1, 19, 1) + h.corrupt(storage.TypeJournal, -1, 32*1024+1000, 1) + + h.openDB() + h.check(36, 36) +} + +func TestCorruptDB_Table(t *testing.T) { + h := newDbCorruptHarness(t) + defer h.close() + + h.build(100) + h.compactMem() + h.compactRangeAt(0, "", "") + h.compactRangeAt(1, "", "") + h.closeDB() + h.corrupt(storage.TypeTable, -1, 100, 1) + + h.openDB() + h.check(99, 99) +} + +func TestCorruptDB_TableIndex(t *testing.T) { + h := newDbCorruptHarness(t) + defer h.close() + + h.build(10000) + h.compactMem() + h.closeDB() + h.corrupt(storage.TypeTable, -1, -2000, 500) + + h.openDB() + h.check(5000, 9999) +} + +func TestCorruptDB_MissingManifest(t *testing.T) { + rnd := rand.New(rand.NewSource(0x0badda7a)) + h := newDbCorruptHarnessWopt(t, &opt.Options{ + BlockCacheCapacity: 100, + Strict: opt.StrictJournalChecksum, + WriteBuffer: 1000 * 60, + }) + defer h.close() + + h.build(1000) + h.compactMem() + h.buildShuffled(1000, rnd) + h.compactMem() + h.deleteRand(500, 1000, rnd) + h.compactMem() + h.buildShuffled(1000, rnd) + h.compactMem() + h.deleteRand(500, 1000, rnd) + h.compactMem() + h.buildShuffled(1000, rnd) + h.compactMem() + h.closeDB() + + h.forceRemoveAll(storage.TypeManifest) + h.openAssert(false) + + h.recover() + h.check(1000, 1000) + h.build(1000) + h.compactMem() + h.compactRange("", "") + h.closeDB() + + h.recover() + h.check(1000, 1000) +} + +func TestCorruptDB_SequenceNumberRecovery(t *testing.T) { + h := newDbCorruptHarness(t) + defer h.close() + + h.put("foo", "v1") + h.put("foo", "v2") + h.put("foo", "v3") + h.put("foo", "v4") + h.put("foo", "v5") + h.closeDB() + + h.recover() + h.getVal("foo", "v5") + h.put("foo", "v6") + h.getVal("foo", "v6") + + h.reopenDB() + h.getVal("foo", "v6") +} + +func TestCorruptDB_SequenceNumberRecoveryTable(t *testing.T) { + h := newDbCorruptHarness(t) + defer h.close() + + h.put("foo", "v1") + h.put("foo", "v2") + h.put("foo", "v3") + h.compactMem() + h.put("foo", "v4") + h.put("foo", "v5") + h.compactMem() + h.closeDB() + + h.recover() + h.getVal("foo", "v5") + h.put("foo", "v6") + h.getVal("foo", "v6") + + h.reopenDB() + h.getVal("foo", "v6") +} + +func TestCorruptDB_CorruptedManifest(t *testing.T) { + h := newDbCorruptHarness(t) + defer h.close() + + h.put("foo", "hello") + h.compactMem() + h.compactRange("", "") + h.closeDB() + h.corrupt(storage.TypeManifest, -1, 0, 1000) + h.openAssert(false) + + h.recover() + h.getVal("foo", "hello") +} + +func TestCorruptDB_CompactionInputError(t *testing.T) { + h := 
newDbCorruptHarness(t) + defer h.close() + + h.build(10) + h.compactMem() + h.closeDB() + h.corrupt(storage.TypeTable, -1, 100, 1) + + h.openDB() + h.check(9, 9) + + h.build(10000) + h.check(10000, 10000) +} + +func TestCorruptDB_UnrelatedKeys(t *testing.T) { + h := newDbCorruptHarness(t) + defer h.close() + + h.build(10) + h.compactMem() + h.closeDB() + h.corrupt(storage.TypeTable, -1, 100, 1) + + h.openDB() + h.put(string(tkey(1000)), string(tval(1000, ctValSize))) + h.getVal(string(tkey(1000)), string(tval(1000, ctValSize))) + h.compactMem() + h.getVal(string(tkey(1000)), string(tval(1000, ctValSize))) +} + +func TestCorruptDB_Level0NewerFileHasOlderSeqnum(t *testing.T) { + h := newDbCorruptHarness(t) + defer h.close() + + h.put("a", "v1") + h.put("b", "v1") + h.compactMem() + h.put("a", "v2") + h.put("b", "v2") + h.compactMem() + h.put("a", "v3") + h.put("b", "v3") + h.compactMem() + h.put("c", "v0") + h.put("d", "v0") + h.compactMem() + h.compactRangeAt(1, "", "") + h.closeDB() + + h.recover() + h.getVal("a", "v3") + h.getVal("b", "v3") + h.getVal("c", "v0") + h.getVal("d", "v0") +} + +func TestCorruptDB_RecoverInvalidSeq_Issue53(t *testing.T) { + h := newDbCorruptHarness(t) + defer h.close() + + h.put("a", "v1") + h.put("b", "v1") + h.compactMem() + h.put("a", "v2") + h.put("b", "v2") + h.compactMem() + h.put("a", "v3") + h.put("b", "v3") + h.compactMem() + h.put("c", "v0") + h.put("d", "v0") + h.compactMem() + h.compactRangeAt(0, "", "") + h.closeDB() + + h.recover() + h.getVal("a", "v3") + h.getVal("b", "v3") + h.getVal("c", "v0") + h.getVal("d", "v0") +} + +func TestCorruptDB_MissingTableFiles(t *testing.T) { + h := newDbCorruptHarness(t) + defer h.close() + + h.put("a", "v1") + h.put("b", "v1") + h.compactMem() + h.put("c", "v2") + h.put("d", "v2") + h.compactMem() + h.put("e", "v3") + h.put("f", "v3") + h.closeDB() + + h.removeOne(storage.TypeTable) + h.openAssert(false) +} + +func TestCorruptDB_RecoverTable(t *testing.T) { + h := newDbCorruptHarnessWopt(t, &opt.Options{ + WriteBuffer: 112 * opt.KiB, + CompactionTableSize: 90 * opt.KiB, + Filter: filter.NewBloomFilter(10), + }) + defer h.close() + + h.build(1000) + h.compactMem() + h.compactRangeAt(0, "", "") + h.compactRangeAt(1, "", "") + seq := h.db.seq + time.Sleep(100 * time.Millisecond) // Wait lazy reference finish tasks + h.closeDB() + h.corrupt(storage.TypeTable, 0, 1000, 1) + h.corrupt(storage.TypeTable, 3, 10000, 1) + // Corrupted filter shouldn't affect recovery. + h.corrupt(storage.TypeTable, 3, 113888, 10) + h.corrupt(storage.TypeTable, -1, 20000, 1) + + h.recover() + if h.db.seq != seq { + t.Errorf("invalid seq, want=%d got=%d", seq, h.db.seq) + } + h.check(985, 985) +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/db.go b/vendor/github.com/syndtr/goleveldb/leveldb/db.go new file mode 100644 index 0000000..74e9826 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/db.go @@ -0,0 +1,1205 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
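For readers of the corruption tests above: corrupt() damages a file by XOR-ing the high bit of a byte range in place, which invalidates block checksums without changing the file's length. A restatement of that one trick as a standalone helper (not code from the patch):

// flipBits toggles the top bit of n bytes starting at offset, mirroring
// the buf[offset+i] ^= 0x80 loop in dbCorruptHarness.corrupt.
func flipBits(buf []byte, offset, n int) {
	for i := 0; i < n; i++ {
		buf[offset+i] ^= 0x80
	}
}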
+ +package leveldb + +import ( + "container/list" + "fmt" + "io" + "os" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/journal" + "github.com/syndtr/goleveldb/leveldb/memdb" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" + "github.com/syndtr/goleveldb/leveldb/table" + "github.com/syndtr/goleveldb/leveldb/util" +) + +// DB is a LevelDB database. +type DB struct { + // Need 64-bit alignment. + seq uint64 + + // Stats. Need 64-bit alignment. + cWriteDelay int64 // The cumulative duration of write delays + cWriteDelayN int32 // The cumulative number of write delays + inWritePaused int32 // The indicator whether write operation is paused by compaction + aliveSnaps, aliveIters int32 + + // Compaction statistic + memComp uint32 // The cumulative number of memory compaction + level0Comp uint32 // The cumulative number of level0 compaction + nonLevel0Comp uint32 // The cumulative number of non-level0 compaction + seekComp uint32 // The cumulative number of seek compaction + + // Session. + s *session + + // MemDB. + memMu sync.RWMutex + memPool chan *memdb.DB + mem, frozenMem *memDB + journal *journal.Writer + journalWriter storage.Writer + journalFd storage.FileDesc + frozenJournalFd storage.FileDesc + frozenSeq uint64 + + // Snapshot. + snapsMu sync.Mutex + snapsList *list.List + + // Write. + batchPool sync.Pool + writeMergeC chan writeMerge + writeMergedC chan bool + writeLockC chan struct{} + writeAckC chan error + writeDelay time.Duration + writeDelayN int + tr *Transaction + + // Compaction. + compCommitLk sync.Mutex + tcompCmdC chan cCmd + tcompPauseC chan chan<- struct{} + mcompCmdC chan cCmd + compErrC chan error + compPerErrC chan error + compErrSetC chan error + compWriteLocking bool + compStats cStats + memdbMaxLevel int // For testing. + + // Close. + closeW sync.WaitGroup + closeC chan struct{} + closed uint32 + closer io.Closer +} + +func openDB(s *session) (*DB, error) { + s.log("db@open opening") + start := time.Now() + db := &DB{ + s: s, + // Initial sequence + seq: s.stSeqNum, + // MemDB + memPool: make(chan *memdb.DB, 1), + // Snapshot + snapsList: list.New(), + // Write + batchPool: sync.Pool{New: newBatch}, + writeMergeC: make(chan writeMerge), + writeMergedC: make(chan bool), + writeLockC: make(chan struct{}, 1), + writeAckC: make(chan error), + // Compaction + tcompCmdC: make(chan cCmd), + tcompPauseC: make(chan chan<- struct{}), + mcompCmdC: make(chan cCmd), + compErrC: make(chan error), + compPerErrC: make(chan error), + compErrSetC: make(chan error), + // Close + closeC: make(chan struct{}), + } + + // Read-only mode. + readOnly := s.o.GetReadOnly() + + if readOnly { + // Recover journals (read-only mode). + if err := db.recoverJournalRO(); err != nil { + return nil, err + } + } else { + // Recover journals. + if err := db.recoverJournal(); err != nil { + return nil, err + } + + // Remove any obsolete files. + if err := db.checkAndCleanFiles(); err != nil { + // Close journal. + if db.journal != nil { + db.journal.Close() + db.journalWriter.Close() + } + return nil, err + } + + } + + // Doesn't need to be included in the wait group. 
+	go db.compactionError()
+	go db.mpoolDrain()
+
+	if readOnly {
+		db.SetReadOnly()
+	} else {
+		db.closeW.Add(2)
+		go db.tCompaction()
+		go db.mCompaction()
+		// go db.jWriter()
+	}
+
+	s.logf("db@open done T·%v", time.Since(start))
+
+	runtime.SetFinalizer(db, (*DB).Close)
+	return db, nil
+}
+
+// Open opens or creates a DB for the given storage.
+// The DB will be created if it does not exist, unless ErrorIfMissing is true.
+// Also, if ErrorIfExist is true and the DB exists, Open will return an
+// os.ErrExist error.
+//
+// Open will return an error of type ErrCorrupted if corruption is
+// detected in the DB. Use errors.IsCorrupted to test whether an error is
+// due to corruption. A corrupted DB can be recovered with the Recover
+// function.
+//
+// The returned DB instance is safe for concurrent use.
+// The DB must be closed after use, by calling Close method.
+func Open(stor storage.Storage, o *opt.Options) (db *DB, err error) {
+	s, err := newSession(stor, o)
+	if err != nil {
+		return
+	}
+	defer func() {
+		if err != nil {
+			s.close()
+			s.release()
+		}
+	}()
+
+	err = s.recover()
+	if err != nil {
+		if !os.IsNotExist(err) || s.o.GetErrorIfMissing() || s.o.GetReadOnly() {
+			return
+		}
+		err = s.create()
+		if err != nil {
+			return
+		}
+	} else if s.o.GetErrorIfExist() {
+		err = os.ErrExist
+		return
+	}
+
+	return openDB(s)
+}
+
+// OpenFile opens or creates a DB for the given path.
+// The DB will be created if it does not exist, unless ErrorIfMissing is true.
+// Also, if ErrorIfExist is true and the DB exists, OpenFile will return an
+// os.ErrExist error.
+//
+// OpenFile uses the standard file-system backed storage implementation as
+// described in the leveldb/storage package.
+//
+// OpenFile will return an error of type ErrCorrupted if corruption is
+// detected in the DB. Use errors.IsCorrupted to test whether an error is
+// due to corruption. A corrupted DB can be recovered with the Recover
+// function.
+//
+// The returned DB instance is safe for concurrent use.
+// The DB must be closed after use, by calling Close method.
+func OpenFile(path string, o *opt.Options) (db *DB, err error) {
+	stor, err := storage.OpenFile(path, o.GetReadOnly())
+	if err != nil {
+		return
+	}
+	db, err = Open(stor, o)
+	if err != nil {
+		stor.Close()
+	} else {
+		db.closer = stor
+	}
+	return
+}
+
+// Recover recovers and opens a DB with missing or corrupted manifest files
+// for the given storage. It will ignore any manifest files, valid or not.
+// The DB must already exist or it will return an error.
+// Also, Recover will ignore ErrorIfMissing and ErrorIfExist options.
+//
+// The returned DB instance is safe for concurrent use.
+// The DB must be closed after use, by calling Close method.
+func Recover(stor storage.Storage, o *opt.Options) (db *DB, err error) {
+	s, err := newSession(stor, o)
+	if err != nil {
+		return
+	}
+	defer func() {
+		if err != nil {
+			s.close()
+			s.release()
+		}
+	}()
+
+	err = recoverTable(s, o)
+	if err != nil {
+		return
+	}
+	return openDB(s)
+}
+
+// RecoverFile recovers and opens a DB with missing or corrupted manifest files
+// for the given path. It will ignore any manifest files, valid or not.
+// The DB must already exist or it will return an error.
+// Also, RecoverFile will ignore ErrorIfMissing and ErrorIfExist options.
+//
+// RecoverFile uses the standard file-system backed storage implementation as
+// described in the leveldb/storage package.
+//
+// The returned DB instance is safe for concurrent use.
+// The DB must be closed after use, by calling Close method.
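+//
+// A minimal recovery sketch (the path is illustrative only):
+//
+//	db, err := leveldb.RecoverFile("path/to/db", nil)
+//	if err != nil {
+//		// handle recovery failure
+//	}
+//	defer db.Close()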
+func RecoverFile(path string, o *opt.Options) (db *DB, err error) { + stor, err := storage.OpenFile(path, false) + if err != nil { + return + } + db, err = Recover(stor, o) + if err != nil { + stor.Close() + } else { + db.closer = stor + } + return +} + +func recoverTable(s *session, o *opt.Options) error { + o = dupOptions(o) + // Mask StrictReader, lets StrictRecovery doing its job. + o.Strict &= ^opt.StrictReader + + // Get all tables and sort it by file number. + fds, err := s.stor.List(storage.TypeTable) + if err != nil { + return err + } + sortFds(fds) + + var ( + maxSeq uint64 + recoveredKey, goodKey, corruptedKey, corruptedBlock, droppedTable int + + // We will drop corrupted table. + strict = o.GetStrict(opt.StrictRecovery) + noSync = o.GetNoSync() + + rec = &sessionRecord{} + bpool = util.NewBufferPool(o.GetBlockSize() + 5) + ) + buildTable := func(iter iterator.Iterator) (tmpFd storage.FileDesc, size int64, err error) { + tmpFd = s.newTemp() + writer, err := s.stor.Create(tmpFd) + if err != nil { + return + } + defer func() { + writer.Close() + if err != nil { + s.stor.Remove(tmpFd) + tmpFd = storage.FileDesc{} + } + }() + + // Copy entries. + tw := table.NewWriter(writer, o) + for iter.Next() { + key := iter.Key() + if validInternalKey(key) { + err = tw.Append(key, iter.Value()) + if err != nil { + return + } + } + } + err = iter.Error() + if err != nil && !errors.IsCorrupted(err) { + return + } + err = tw.Close() + if err != nil { + return + } + if !noSync { + err = writer.Sync() + if err != nil { + return + } + } + size = int64(tw.BytesLen()) + return + } + recoverTable := func(fd storage.FileDesc) error { + s.logf("table@recovery recovering @%d", fd.Num) + reader, err := s.stor.Open(fd) + if err != nil { + return err + } + var closed bool + defer func() { + if !closed { + reader.Close() + } + }() + + // Get file size. + size, err := reader.Seek(0, 2) + if err != nil { + return err + } + + var ( + tSeq uint64 + tgoodKey, tcorruptedKey, tcorruptedBlock int + imin, imax []byte + ) + tr, err := table.NewReader(reader, size, fd, nil, bpool, o) + if err != nil { + return err + } + iter := tr.NewIterator(nil, nil) + if itererr, ok := iter.(iterator.ErrorCallbackSetter); ok { + itererr.SetErrorCallback(func(err error) { + if errors.IsCorrupted(err) { + s.logf("table@recovery block corruption @%d %q", fd.Num, err) + tcorruptedBlock++ + } + }) + } + + // Scan the table. + for iter.Next() { + key := iter.Key() + _, seq, _, kerr := parseInternalKey(key) + if kerr != nil { + tcorruptedKey++ + continue + } + tgoodKey++ + if seq > tSeq { + tSeq = seq + } + if imin == nil { + imin = append([]byte{}, key...) + } + imax = append(imax[:0], key...) + } + if err := iter.Error(); err != nil && !errors.IsCorrupted(err) { + iter.Release() + return err + } + iter.Release() + + goodKey += tgoodKey + corruptedKey += tcorruptedKey + corruptedBlock += tcorruptedBlock + + if strict && (tcorruptedKey > 0 || tcorruptedBlock > 0) { + droppedTable++ + s.logf("table@recovery dropped @%d Gk·%d Ck·%d Cb·%d S·%d Q·%d", fd.Num, tgoodKey, tcorruptedKey, tcorruptedBlock, size, tSeq) + return nil + } + + if tgoodKey > 0 { + if tcorruptedKey > 0 || tcorruptedBlock > 0 { + // Rebuild the table. 
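+				// The rebuild copies every key that still parses into a
+				// fresh temporary table, then renames it over the corrupted
+				// file below, so only the unreadable entries are lost.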
+ s.logf("table@recovery rebuilding @%d", fd.Num) + iter := tr.NewIterator(nil, nil) + tmpFd, newSize, err := buildTable(iter) + iter.Release() + if err != nil { + return err + } + closed = true + reader.Close() + if err := s.stor.Rename(tmpFd, fd); err != nil { + return err + } + size = newSize + } + if tSeq > maxSeq { + maxSeq = tSeq + } + recoveredKey += tgoodKey + // Add table to level 0. + rec.addTable(0, fd.Num, size, imin, imax) + s.logf("table@recovery recovered @%d Gk·%d Ck·%d Cb·%d S·%d Q·%d", fd.Num, tgoodKey, tcorruptedKey, tcorruptedBlock, size, tSeq) + } else { + droppedTable++ + s.logf("table@recovery unrecoverable @%d Ck·%d Cb·%d S·%d", fd.Num, tcorruptedKey, tcorruptedBlock, size) + } + + return nil + } + + // Recover all tables. + if len(fds) > 0 { + s.logf("table@recovery F·%d", len(fds)) + + // Mark file number as used. + s.markFileNum(fds[len(fds)-1].Num) + + for _, fd := range fds { + if err := recoverTable(fd); err != nil { + return err + } + } + + s.logf("table@recovery recovered F·%d N·%d Gk·%d Ck·%d Q·%d", len(fds), recoveredKey, goodKey, corruptedKey, maxSeq) + } + + // Set sequence number. + rec.setSeqNum(maxSeq) + + // Create new manifest. + if err := s.create(); err != nil { + return err + } + + // Commit. + return s.commit(rec, false) +} + +func (db *DB) recoverJournal() error { + // Get all journals and sort it by file number. + rawFds, err := db.s.stor.List(storage.TypeJournal) + if err != nil { + return err + } + sortFds(rawFds) + + // Journals that will be recovered. + var fds []storage.FileDesc + for _, fd := range rawFds { + if fd.Num >= db.s.stJournalNum || fd.Num == db.s.stPrevJournalNum { + fds = append(fds, fd) + } + } + + var ( + ofd storage.FileDesc // Obsolete file. + rec = &sessionRecord{} + ) + + // Recover journals. + if len(fds) > 0 { + db.logf("journal@recovery F·%d", len(fds)) + + // Mark file number as used. + db.s.markFileNum(fds[len(fds)-1].Num) + + var ( + // Options. + strict = db.s.o.GetStrict(opt.StrictJournal) + checksum = db.s.o.GetStrict(opt.StrictJournalChecksum) + writeBuffer = db.s.o.GetWriteBuffer() + + jr *journal.Reader + mdb = memdb.New(db.s.icmp, writeBuffer) + buf = &util.Buffer{} + batchSeq uint64 + batchLen int + ) + + for _, fd := range fds { + db.logf("journal@recovery recovering @%d", fd.Num) + + fr, err := db.s.stor.Open(fd) + if err != nil { + return err + } + + // Create or reset journal reader instance. + if jr == nil { + jr = journal.NewReader(fr, dropper{db.s, fd}, strict, checksum) + } else { + jr.Reset(fr, dropper{db.s, fd}, strict, checksum) + } + + // Flush memdb and remove obsolete journal file. + if !ofd.Zero() { + if mdb.Len() > 0 { + if _, err := db.s.flushMemdb(rec, mdb, 0); err != nil { + fr.Close() + return err + } + } + + rec.setJournalNum(fd.Num) + rec.setSeqNum(db.seq) + if err := db.s.commit(rec, false); err != nil { + fr.Close() + return err + } + rec.resetAddedTables() + + db.s.stor.Remove(ofd) + ofd = storage.FileDesc{} + } + + // Replay journal to memdb. + mdb.Reset() + for { + r, err := jr.Next() + if err != nil { + if err == io.EOF { + break + } + + fr.Close() + return errors.SetFd(err, fd) + } + + buf.Reset() + if _, err := buf.ReadFrom(r); err != nil { + if err == io.ErrUnexpectedEOF { + // This is error returned due to corruption, with strict == false. 
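+					// Skip the truncated chunk and move on to the next
+					// journal record; the data it carried is dropped.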
+ continue + } + + fr.Close() + return errors.SetFd(err, fd) + } + batchSeq, batchLen, err = decodeBatchToMem(buf.Bytes(), db.seq, mdb) + if err != nil { + if !strict && errors.IsCorrupted(err) { + db.s.logf("journal error: %v (skipped)", err) + // We won't apply sequence number as it might be corrupted. + continue + } + + fr.Close() + return errors.SetFd(err, fd) + } + + // Save sequence number. + db.seq = batchSeq + uint64(batchLen) + + // Flush it if large enough. + if mdb.Size() >= writeBuffer { + if _, err := db.s.flushMemdb(rec, mdb, 0); err != nil { + fr.Close() + return err + } + + mdb.Reset() + } + } + + fr.Close() + ofd = fd + } + + // Flush the last memdb. + if mdb.Len() > 0 { + if _, err := db.s.flushMemdb(rec, mdb, 0); err != nil { + return err + } + } + } + + // Create a new journal. + if _, err := db.newMem(0); err != nil { + return err + } + + // Commit. + rec.setJournalNum(db.journalFd.Num) + rec.setSeqNum(db.seq) + if err := db.s.commit(rec, false); err != nil { + // Close journal on error. + if db.journal != nil { + db.journal.Close() + db.journalWriter.Close() + } + return err + } + + // Remove the last obsolete journal file. + if !ofd.Zero() { + db.s.stor.Remove(ofd) + } + + return nil +} + +func (db *DB) recoverJournalRO() error { + // Get all journals and sort it by file number. + rawFds, err := db.s.stor.List(storage.TypeJournal) + if err != nil { + return err + } + sortFds(rawFds) + + // Journals that will be recovered. + var fds []storage.FileDesc + for _, fd := range rawFds { + if fd.Num >= db.s.stJournalNum || fd.Num == db.s.stPrevJournalNum { + fds = append(fds, fd) + } + } + + var ( + // Options. + strict = db.s.o.GetStrict(opt.StrictJournal) + checksum = db.s.o.GetStrict(opt.StrictJournalChecksum) + writeBuffer = db.s.o.GetWriteBuffer() + + mdb = memdb.New(db.s.icmp, writeBuffer) + ) + + // Recover journals. + if len(fds) > 0 { + db.logf("journal@recovery RO·Mode F·%d", len(fds)) + + var ( + jr *journal.Reader + buf = &util.Buffer{} + batchSeq uint64 + batchLen int + ) + + for _, fd := range fds { + db.logf("journal@recovery recovering @%d", fd.Num) + + fr, err := db.s.stor.Open(fd) + if err != nil { + return err + } + + // Create or reset journal reader instance. + if jr == nil { + jr = journal.NewReader(fr, dropper{db.s, fd}, strict, checksum) + } else { + jr.Reset(fr, dropper{db.s, fd}, strict, checksum) + } + + // Replay journal to memdb. + for { + r, err := jr.Next() + if err != nil { + if err == io.EOF { + break + } + + fr.Close() + return errors.SetFd(err, fd) + } + + buf.Reset() + if _, err := buf.ReadFrom(r); err != nil { + if err == io.ErrUnexpectedEOF { + // This is error returned due to corruption, with strict == false. + continue + } + + fr.Close() + return errors.SetFd(err, fd) + } + batchSeq, batchLen, err = decodeBatchToMem(buf.Bytes(), db.seq, mdb) + if err != nil { + if !strict && errors.IsCorrupted(err) { + db.s.logf("journal error: %v (skipped)", err) + // We won't apply sequence number as it might be corrupted. + continue + } + + fr.Close() + return errors.SetFd(err, fd) + } + + // Save sequence number. + db.seq = batchSeq + uint64(batchLen) + } + + fr.Close() + } + } + + // Set memDB. + db.mem = &memDB{db: db, DB: mdb, ref: 1} + + return nil +} + +func memGet(mdb *memdb.DB, ikey internalKey, icmp *iComparer) (ok bool, mv []byte, err error) { + mk, mv, err := mdb.Find(ikey) + if err == nil { + ukey, _, kt, kerr := parseInternalKey(mk) + if kerr != nil { + // Shouldn't have had happen. 
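+			// Keys in the memdb were encoded internally, so a parse
+			// failure here indicates a bug rather than on-disk corruption.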
+			panic(kerr)
+		}
+		if icmp.uCompare(ukey, ikey.ukey()) == 0 {
+			if kt == keyTypeDel {
+				return true, nil, ErrNotFound
+			}
+			return true, mv, nil
+		}
+	} else if err != ErrNotFound {
+		return true, nil, err
+	}
+	return
+}
+
+func (db *DB) get(auxm *memdb.DB, auxt tFiles, key []byte, seq uint64, ro *opt.ReadOptions) (value []byte, err error) {
+	ikey := makeInternalKey(nil, key, seq, keyTypeSeek)
+
+	if auxm != nil {
+		if ok, mv, me := memGet(auxm, ikey, db.s.icmp); ok {
+			return append([]byte{}, mv...), me
+		}
+	}
+
+	em, fm := db.getMems()
+	for _, m := range [...]*memDB{em, fm} {
+		if m == nil {
+			continue
+		}
+		defer m.decref()
+
+		if ok, mv, me := memGet(m.DB, ikey, db.s.icmp); ok {
+			return append([]byte{}, mv...), me
+		}
+	}
+
+	v := db.s.version()
+	value, cSched, err := v.get(auxt, ikey, ro, false)
+	v.release()
+	if cSched {
+		// Trigger table compaction.
+		db.compTrigger(db.tcompCmdC)
+	}
+	return
+}
+
+func nilIfNotFound(err error) error {
+	if err == ErrNotFound {
+		return nil
+	}
+	return err
+}
+
+func (db *DB) has(auxm *memdb.DB, auxt tFiles, key []byte, seq uint64, ro *opt.ReadOptions) (ret bool, err error) {
+	ikey := makeInternalKey(nil, key, seq, keyTypeSeek)
+
+	if auxm != nil {
+		if ok, _, me := memGet(auxm, ikey, db.s.icmp); ok {
+			return me == nil, nilIfNotFound(me)
+		}
+	}
+
+	em, fm := db.getMems()
+	for _, m := range [...]*memDB{em, fm} {
+		if m == nil {
+			continue
+		}
+		defer m.decref()
+
+		if ok, _, me := memGet(m.DB, ikey, db.s.icmp); ok {
+			return me == nil, nilIfNotFound(me)
+		}
+	}
+
+	v := db.s.version()
+	_, cSched, err := v.get(auxt, ikey, ro, true)
+	v.release()
+	if cSched {
+		// Trigger table compaction.
+		db.compTrigger(db.tcompCmdC)
+	}
+	if err == nil {
+		ret = true
+	} else if err == ErrNotFound {
+		err = nil
+	}
+	return
+}
+
+// Get gets the value for the given key. It returns ErrNotFound if the
+// DB does not contain the key.
+//
+// The returned slice is its own copy; it is safe to modify the contents
+// of the returned slice.
+// It is safe to modify the contents of the argument after Get returns.
+func (db *DB) Get(key []byte, ro *opt.ReadOptions) (value []byte, err error) {
+	err = db.ok()
+	if err != nil {
+		return
+	}
+
+	se := db.acquireSnapshot()
+	defer db.releaseSnapshot(se)
+	return db.get(nil, nil, key, se.seq, ro)
+}
+
+// Has returns true if the DB does contain the given key.
+//
+// It is safe to modify the contents of the argument after Has returns.
+func (db *DB) Has(key []byte, ro *opt.ReadOptions) (ret bool, err error) {
+	err = db.ok()
+	if err != nil {
+		return
+	}
+
+	se := db.acquireSnapshot()
+	defer db.releaseSnapshot(se)
+	return db.has(nil, nil, key, se.seq, ro)
+}
+
+// NewIterator returns an iterator for the latest snapshot of the
+// underlying DB.
+// The returned iterator is not safe for concurrent use, but it is safe to use
+// multiple iterators concurrently, with each in a dedicated goroutine.
+// It is also safe to use an iterator concurrently with modifying its
+// underlying DB. The resultant key/value pairs are guaranteed to be
+// consistent.
+//
+// Slice allows slicing the iterator to only contain keys in the given
+// range. A nil Range.Start is treated as a key before all keys in the
+// DB, and a nil Range.Limit is treated as a key after all keys in
+// the DB.
+//
+// WARNING: The contents of any slice returned by the iterator (e.g. a slice
+// returned by the Iterator.Key() or Iterator.Value() methods) should not be
+// modified unless noted otherwise.
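+//
+// A typical iteration sketch (error handling kept minimal):
+//
+//	iter := db.NewIterator(nil, nil)
+//	for iter.Next() {
+//		// Use iter.Key() and iter.Value() here; copy them if they must
+//		// outlive the iteration.
+//	}
+//	iter.Release()
+//	err := iter.Error()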
+// +// The iterator must be released after use, by calling Release method. +// +// Also read Iterator documentation of the leveldb/iterator package. +func (db *DB) NewIterator(slice *util.Range, ro *opt.ReadOptions) iterator.Iterator { + if err := db.ok(); err != nil { + return iterator.NewEmptyIterator(err) + } + + se := db.acquireSnapshot() + defer db.releaseSnapshot(se) + // Iterator holds 'version' lock, 'version' is immutable so snapshot + // can be released after iterator created. + return db.newIterator(nil, nil, se.seq, slice, ro) +} + +// GetSnapshot returns a latest snapshot of the underlying DB. A snapshot +// is a frozen snapshot of a DB state at a particular point in time. The +// content of snapshot are guaranteed to be consistent. +// +// The snapshot must be released after use, by calling Release method. +func (db *DB) GetSnapshot() (*Snapshot, error) { + if err := db.ok(); err != nil { + return nil, err + } + + return db.newSnapshot(), nil +} + +// GetProperty returns value of the given property name. +// +// Property names: +// leveldb.num-files-at-level{n} +// Returns the number of files at level 'n'. +// leveldb.stats +// Returns statistics of the underlying DB. +// leveldb.iostats +// Returns statistics of effective disk read and write. +// leveldb.writedelay +// Returns cumulative write delay caused by compaction. +// leveldb.sstables +// Returns sstables list for each level. +// leveldb.blockpool +// Returns block pool stats. +// leveldb.cachedblock +// Returns size of cached block. +// leveldb.openedtables +// Returns number of opened tables. +// leveldb.alivesnaps +// Returns number of alive snapshots. +// leveldb.aliveiters +// Returns number of alive iterators. +func (db *DB) GetProperty(name string) (value string, err error) { + err = db.ok() + if err != nil { + return + } + + const prefix = "leveldb." 
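+	// Every property name is namespaced under "leveldb."; strip the
+	// prefix before dispatching on the remainder below.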
+ if !strings.HasPrefix(name, prefix) { + return "", ErrNotFound + } + p := name[len(prefix):] + + v := db.s.version() + defer v.release() + + numFilesPrefix := "num-files-at-level" + switch { + case strings.HasPrefix(p, numFilesPrefix): + var level uint + var rest string + n, _ := fmt.Sscanf(p[len(numFilesPrefix):], "%d%s", &level, &rest) + if n != 1 { + err = ErrNotFound + } else { + value = fmt.Sprint(v.tLen(int(level))) + } + case p == "stats": + value = "Compactions\n" + + " Level | Tables | Size(MB) | Time(sec) | Read(MB) | Write(MB)\n" + + "-------+------------+---------------+---------------+---------------+---------------\n" + var totalTables int + var totalSize, totalRead, totalWrite int64 + var totalDuration time.Duration + for level, tables := range v.levels { + duration, read, write := db.compStats.getStat(level) + if len(tables) == 0 && duration == 0 { + continue + } + totalTables += len(tables) + totalSize += tables.size() + totalRead += read + totalWrite += write + totalDuration += duration + value += fmt.Sprintf(" %3d | %10d | %13.5f | %13.5f | %13.5f | %13.5f\n", + level, len(tables), float64(tables.size())/1048576.0, duration.Seconds(), + float64(read)/1048576.0, float64(write)/1048576.0) + } + value += "-------+------------+---------------+---------------+---------------+---------------\n" + value += fmt.Sprintf(" Total | %10d | %13.5f | %13.5f | %13.5f | %13.5f\n", + totalTables, float64(totalSize)/1048576.0, totalDuration.Seconds(), + float64(totalRead)/1048576.0, float64(totalWrite)/1048576.0) + case p == "compcount": + value = fmt.Sprintf("MemComp:%d Level0Comp:%d NonLevel0Comp:%d SeekComp:%d", atomic.LoadUint32(&db.memComp), atomic.LoadUint32(&db.level0Comp), atomic.LoadUint32(&db.nonLevel0Comp), atomic.LoadUint32(&db.seekComp)) + case p == "iostats": + value = fmt.Sprintf("Read(MB):%.5f Write(MB):%.5f", + float64(db.s.stor.reads())/1048576.0, + float64(db.s.stor.writes())/1048576.0) + case p == "writedelay": + writeDelayN, writeDelay := atomic.LoadInt32(&db.cWriteDelayN), time.Duration(atomic.LoadInt64(&db.cWriteDelay)) + paused := atomic.LoadInt32(&db.inWritePaused) == 1 + value = fmt.Sprintf("DelayN:%d Delay:%s Paused:%t", writeDelayN, writeDelay, paused) + case p == "sstables": + for level, tables := range v.levels { + value += fmt.Sprintf("--- level %d ---\n", level) + for _, t := range tables { + value += fmt.Sprintf("%d:%d[%q .. %q]\n", t.fd.Num, t.size, t.imin, t.imax) + } + } + case p == "blockpool": + value = fmt.Sprintf("%v", db.s.tops.bpool) + case p == "cachedblock": + if db.s.tops.bcache != nil { + value = fmt.Sprintf("%d", db.s.tops.bcache.Size()) + } else { + value = "" + } + case p == "openedtables": + value = fmt.Sprintf("%d", db.s.tops.cache.Size()) + case p == "alivesnaps": + value = fmt.Sprintf("%d", atomic.LoadInt32(&db.aliveSnaps)) + case p == "aliveiters": + value = fmt.Sprintf("%d", atomic.LoadInt32(&db.aliveIters)) + default: + err = ErrNotFound + } + + return +} + +// DBStats is database statistics. +type DBStats struct { + WriteDelayCount int32 + WriteDelayDuration time.Duration + WritePaused bool + + AliveSnapshots int32 + AliveIterators int32 + + IOWrite uint64 + IORead uint64 + + BlockCacheSize int + OpenedTablesCount int + + LevelSizes Sizes + LevelTablesCounts []int + LevelRead Sizes + LevelWrite Sizes + LevelDurations []time.Duration + + MemComp uint32 + Level0Comp uint32 + NonLevel0Comp uint32 + SeekComp uint32 +} + +// Stats populates s with database statistics. 
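+//
+// A minimal usage sketch:
+//
+//	var stats leveldb.DBStats
+//	if err := db.Stats(&stats); err != nil {
+//		// handle err
+//	}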
+func (db *DB) Stats(s *DBStats) error {
+	err := db.ok()
+	if err != nil {
+		return err
+	}
+
+	s.IORead = db.s.stor.reads()
+	s.IOWrite = db.s.stor.writes()
+	s.WriteDelayCount = atomic.LoadInt32(&db.cWriteDelayN)
+	s.WriteDelayDuration = time.Duration(atomic.LoadInt64(&db.cWriteDelay))
+	s.WritePaused = atomic.LoadInt32(&db.inWritePaused) == 1
+
+	s.OpenedTablesCount = db.s.tops.cache.Size()
+	if db.s.tops.bcache != nil {
+		s.BlockCacheSize = db.s.tops.bcache.Size()
+	} else {
+		s.BlockCacheSize = 0
+	}
+
+	s.AliveIterators = atomic.LoadInt32(&db.aliveIters)
+	s.AliveSnapshots = atomic.LoadInt32(&db.aliveSnaps)
+
+	s.LevelDurations = s.LevelDurations[:0]
+	s.LevelRead = s.LevelRead[:0]
+	s.LevelWrite = s.LevelWrite[:0]
+	s.LevelSizes = s.LevelSizes[:0]
+	s.LevelTablesCounts = s.LevelTablesCounts[:0]
+
+	v := db.s.version()
+	defer v.release()
+
+	for level, tables := range v.levels {
+		duration, read, write := db.compStats.getStat(level)
+
+		s.LevelDurations = append(s.LevelDurations, duration)
+		s.LevelRead = append(s.LevelRead, read)
+		s.LevelWrite = append(s.LevelWrite, write)
+		s.LevelSizes = append(s.LevelSizes, tables.size())
+		s.LevelTablesCounts = append(s.LevelTablesCounts, len(tables))
+	}
+	s.MemComp = atomic.LoadUint32(&db.memComp)
+	s.Level0Comp = atomic.LoadUint32(&db.level0Comp)
+	s.NonLevel0Comp = atomic.LoadUint32(&db.nonLevel0Comp)
+	s.SeekComp = atomic.LoadUint32(&db.seekComp)
+	return nil
+}
+
+// SizeOf calculates approximate sizes of the given key ranges.
+// The length of the returned sizes is equal to the length of the given
+// ranges. The returned sizes measure storage space usage, so if the user
+// data compresses by a factor of ten, the returned sizes will be one-tenth
+// the size of the corresponding user data size.
+// The results may not include the sizes of recently written data.
+func (db *DB) SizeOf(ranges []util.Range) (Sizes, error) {
+	if err := db.ok(); err != nil {
+		return nil, err
+	}
+
+	v := db.s.version()
+	defer v.release()
+
+	sizes := make(Sizes, 0, len(ranges))
+	for _, r := range ranges {
+		imin := makeInternalKey(nil, r.Start, keyMaxSeq, keyTypeSeek)
+		imax := makeInternalKey(nil, r.Limit, keyMaxSeq, keyTypeSeek)
+		start, err := v.offsetOf(imin)
+		if err != nil {
+			return nil, err
+		}
+		limit, err := v.offsetOf(imax)
+		if err != nil {
+			return nil, err
+		}
+		var size int64
+		if limit >= start {
+			size = limit - start
+		}
+		sizes = append(sizes, size)
+	}
+
+	return sizes, nil
+}
+
+// Close closes the DB. This will also release any outstanding snapshots,
+// abort any in-flight compactions and discard any open transaction.
+//
+// It is not safe to close a DB until all outstanding iterators are released.
+// It is valid to call Close multiple times. Other methods should not be
+// called after the DB has been closed.
+func (db *DB) Close() error {
+	if !db.setClosed() {
+		return ErrClosed
+	}
+
+	start := time.Now()
+	db.log("db@close closing")
+
+	// Clear the finalizer.
+	runtime.SetFinalizer(db, nil)
+
+	// Get compaction error.
+	var err error
+	select {
+	case err = <-db.compErrC:
+		if err == ErrReadOnly {
+			err = nil
+		}
+	default:
+	}
+
+	// Signal all goroutines.
+	close(db.closeC)
+
+	// Discard open transaction.
+	if db.tr != nil {
+		db.tr.Discard()
+	}
+
+	// Acquire writer lock.
+	db.writeLockC <- struct{}{}
+
+	// Wait for all goroutines to exit.
+	db.closeW.Wait()
+
+	// Close journal.
+ if db.journal != nil { + db.journal.Close() + db.journalWriter.Close() + db.journal = nil + db.journalWriter = nil + } + + if db.writeDelayN > 0 { + db.logf("db@write was delayed N·%d T·%v", db.writeDelayN, db.writeDelay) + } + + // Close session. + db.s.close() + db.logf("db@close done T·%v", time.Since(start)) + db.s.release() + + if db.closer != nil { + if err1 := db.closer.Close(); err == nil { + err = err1 + } + db.closer = nil + } + + // Clear memdbs. + db.clearMems() + + return err +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/db_compaction.go b/vendor/github.com/syndtr/goleveldb/leveldb/db_compaction.go new file mode 100644 index 0000000..6b70eb2 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/db_compaction.go @@ -0,0 +1,865 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "sync" + "sync/atomic" + "time" + + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" +) + +var ( + errCompactionTransactExiting = errors.New("leveldb: compaction transact exiting") +) + +type cStat struct { + duration time.Duration + read int64 + write int64 +} + +func (p *cStat) add(n *cStatStaging) { + p.duration += n.duration + p.read += n.read + p.write += n.write +} + +func (p *cStat) get() (duration time.Duration, read, write int64) { + return p.duration, p.read, p.write +} + +type cStatStaging struct { + start time.Time + duration time.Duration + on bool + read int64 + write int64 +} + +func (p *cStatStaging) startTimer() { + if !p.on { + p.start = time.Now() + p.on = true + } +} + +func (p *cStatStaging) stopTimer() { + if p.on { + p.duration += time.Since(p.start) + p.on = false + } +} + +type cStats struct { + lk sync.Mutex + stats []cStat +} + +func (p *cStats) addStat(level int, n *cStatStaging) { + p.lk.Lock() + if level >= len(p.stats) { + newStats := make([]cStat, level+1) + copy(newStats, p.stats) + p.stats = newStats + } + p.stats[level].add(n) + p.lk.Unlock() +} + +func (p *cStats) getStat(level int) (duration time.Duration, read, write int64) { + p.lk.Lock() + defer p.lk.Unlock() + if level < len(p.stats) { + return p.stats[level].get() + } + return +} + +func (db *DB) compactionError() { + var err error +noerr: + // No error. + for { + select { + case err = <-db.compErrSetC: + switch { + case err == nil: + case err == ErrReadOnly, errors.IsCorrupted(err): + goto hasperr + default: + goto haserr + } + case <-db.closeC: + return + } + } +haserr: + // Transient error. + for { + select { + case db.compErrC <- err: + case err = <-db.compErrSetC: + switch { + case err == nil: + goto noerr + case err == ErrReadOnly, errors.IsCorrupted(err): + goto hasperr + default: + } + case <-db.closeC: + return + } + } +hasperr: + // Persistent error. + for { + select { + case db.compErrC <- err: + case db.compPerErrC <- err: + case db.writeLockC <- struct{}{}: + // Hold write lock, so that write won't pass-through. + db.compWriteLocking = true + case <-db.closeC: + if db.compWriteLocking { + // We should release the lock or Close will hang. 
+ <-db.writeLockC + } + return + } + } +} + +type compactionTransactCounter int + +func (cnt *compactionTransactCounter) incr() { + *cnt++ +} + +type compactionTransactInterface interface { + run(cnt *compactionTransactCounter) error + revert() error +} + +func (db *DB) compactionTransact(name string, t compactionTransactInterface) { + defer func() { + if x := recover(); x != nil { + if x == errCompactionTransactExiting { + if err := t.revert(); err != nil { + db.logf("%s revert error %q", name, err) + } + } + panic(x) + } + }() + + const ( + backoffMin = 1 * time.Second + backoffMax = 8 * time.Second + backoffMul = 2 * time.Second + ) + var ( + backoff = backoffMin + backoffT = time.NewTimer(backoff) + lastCnt = compactionTransactCounter(0) + + disableBackoff = db.s.o.GetDisableCompactionBackoff() + ) + for n := 0; ; n++ { + // Check whether the DB is closed. + if db.isClosed() { + db.logf("%s exiting", name) + db.compactionExitTransact() + } else if n > 0 { + db.logf("%s retrying N·%d", name, n) + } + + // Execute. + cnt := compactionTransactCounter(0) + err := t.run(&cnt) + if err != nil { + db.logf("%s error I·%d %q", name, cnt, err) + } + + // Set compaction error status. + select { + case db.compErrSetC <- err: + case perr := <-db.compPerErrC: + if err != nil { + db.logf("%s exiting (persistent error %q)", name, perr) + db.compactionExitTransact() + } + case <-db.closeC: + db.logf("%s exiting", name) + db.compactionExitTransact() + } + if err == nil { + return + } + if errors.IsCorrupted(err) { + db.logf("%s exiting (corruption detected)", name) + db.compactionExitTransact() + } + + if !disableBackoff { + // Reset backoff duration if counter is advancing. + if cnt > lastCnt { + backoff = backoffMin + lastCnt = cnt + } + + // Backoff. + backoffT.Reset(backoff) + if backoff < backoffMax { + backoff *= backoffMul + if backoff > backoffMax { + backoff = backoffMax + } + } + select { + case <-backoffT.C: + case <-db.closeC: + db.logf("%s exiting", name) + db.compactionExitTransact() + } + } + } +} + +type compactionTransactFunc struct { + runFunc func(cnt *compactionTransactCounter) error + revertFunc func() error +} + +func (t *compactionTransactFunc) run(cnt *compactionTransactCounter) error { + return t.runFunc(cnt) +} + +func (t *compactionTransactFunc) revert() error { + if t.revertFunc != nil { + return t.revertFunc() + } + return nil +} + +func (db *DB) compactionTransactFunc(name string, run func(cnt *compactionTransactCounter) error, revert func() error) { + db.compactionTransact(name, &compactionTransactFunc{run, revert}) +} + +func (db *DB) compactionExitTransact() { + panic(errCompactionTransactExiting) +} + +func (db *DB) compactionCommit(name string, rec *sessionRecord) { + db.compCommitLk.Lock() + defer db.compCommitLk.Unlock() // Defer is necessary. + db.compactionTransactFunc(name+"@commit", func(cnt *compactionTransactCounter) error { + return db.s.commit(rec, true) + }, nil) +} + +func (db *DB) memCompaction() { + mdb := db.getFrozenMem() + if mdb == nil { + return + } + defer mdb.decref() + + db.logf("memdb@flush N·%d S·%s", mdb.Len(), shortenb(mdb.Size())) + + // Don't compact empty memdb. + if mdb.Len() == 0 { + db.logf("memdb@flush skipping") + // drop frozen memdb + db.dropFrozenMem() + return + } + + // Pause table compaction. 
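+	// Handing resumeC to the table-compaction goroutine parks it inside
+	// pauseCompaction (blocked sending on resumeC) until this function
+	// receives from resumeC after the flush, which acts as the resume signal.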
+ resumeC := make(chan struct{}) + select { + case db.tcompPauseC <- (chan<- struct{})(resumeC): + case <-db.compPerErrC: + close(resumeC) + resumeC = nil + case <-db.closeC: + db.compactionExitTransact() + } + + var ( + rec = &sessionRecord{} + stats = &cStatStaging{} + flushLevel int + ) + + // Generate tables. + db.compactionTransactFunc("memdb@flush", func(cnt *compactionTransactCounter) (err error) { + stats.startTimer() + flushLevel, err = db.s.flushMemdb(rec, mdb.DB, db.memdbMaxLevel) + stats.stopTimer() + return + }, func() error { + for _, r := range rec.addedTables { + db.logf("memdb@flush revert @%d", r.num) + if err := db.s.stor.Remove(storage.FileDesc{Type: storage.TypeTable, Num: r.num}); err != nil { + return err + } + } + return nil + }) + + rec.setJournalNum(db.journalFd.Num) + rec.setSeqNum(db.frozenSeq) + + // Commit. + stats.startTimer() + db.compactionCommit("memdb", rec) + stats.stopTimer() + + db.logf("memdb@flush committed F·%d T·%v", len(rec.addedTables), stats.duration) + + // Save compaction stats + for _, r := range rec.addedTables { + stats.write += r.size + } + db.compStats.addStat(flushLevel, stats) + atomic.AddUint32(&db.memComp, 1) + + // Drop frozen memdb. + db.dropFrozenMem() + + // Resume table compaction. + if resumeC != nil { + select { + case <-resumeC: + close(resumeC) + case <-db.closeC: + db.compactionExitTransact() + } + } + + // Trigger table compaction. + db.compTrigger(db.tcompCmdC) +} + +type tableCompactionBuilder struct { + db *DB + s *session + c *compaction + rec *sessionRecord + stat0, stat1 *cStatStaging + + snapHasLastUkey bool + snapLastUkey []byte + snapLastSeq uint64 + snapIter int + snapKerrCnt int + snapDropCnt int + + kerrCnt int + dropCnt int + + minSeq uint64 + strict bool + tableSize int + + tw *tWriter +} + +func (b *tableCompactionBuilder) appendKV(key, value []byte) error { + // Create new table if not already. + if b.tw == nil { + // Check for pause event. + if b.db != nil { + select { + case ch := <-b.db.tcompPauseC: + b.db.pauseCompaction(ch) + case <-b.db.closeC: + b.db.compactionExitTransact() + default: + } + } + + // Create new table. + var err error + b.tw, err = b.s.tops.create() + if err != nil { + return err + } + } + + // Write key/value into table. + return b.tw.append(key, value) +} + +func (b *tableCompactionBuilder) needFlush() bool { + return b.tw.tw.BytesLen() >= b.tableSize +} + +func (b *tableCompactionBuilder) flush() error { + t, err := b.tw.finish() + if err != nil { + return err + } + b.rec.addTableFile(b.c.sourceLevel+1, t) + b.stat1.write += t.size + b.s.logf("table@build created L%d@%d N·%d S·%s %q:%q", b.c.sourceLevel+1, t.fd.Num, b.tw.tw.EntriesLen(), shortenb(int(t.size)), t.imin, t.imax) + b.tw = nil + return nil +} + +func (b *tableCompactionBuilder) cleanup() { + if b.tw != nil { + b.tw.drop() + b.tw = nil + } +} + +func (b *tableCompactionBuilder) run(cnt *compactionTransactCounter) error { + snapResumed := b.snapIter > 0 + hasLastUkey := b.snapHasLastUkey // The key might has zero length, so this is necessary. + lastUkey := append([]byte{}, b.snapLastUkey...) + lastSeq := b.snapLastSeq + b.kerrCnt = b.snapKerrCnt + b.dropCnt = b.snapDropCnt + // Restore compaction state. + b.c.restore() + + defer b.cleanup() + + b.stat1.startTimer() + defer b.stat1.stopTimer() + + iter := b.c.newIterator() + defer iter.Release() + for i := 0; iter.Next(); i++ { + // Incr transact counter. + cnt.incr() + + // Skip until last state. 
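+		// After a transient error the transact framework re-runs this
+		// function; entries up to the last saved snapshot (snapIter) were
+		// already written, so replay past them without re-appending.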
+ if i < b.snapIter { + continue + } + + resumed := false + if snapResumed { + resumed = true + snapResumed = false + } + + ikey := iter.Key() + ukey, seq, kt, kerr := parseInternalKey(ikey) + + if kerr == nil { + shouldStop := !resumed && b.c.shouldStopBefore(ikey) + + if !hasLastUkey || b.s.icmp.uCompare(lastUkey, ukey) != 0 { + // First occurrence of this user key. + + // Only rotate tables if ukey doesn't hop across. + if b.tw != nil && (shouldStop || b.needFlush()) { + if err := b.flush(); err != nil { + return err + } + + // Creates snapshot of the state. + b.c.save() + b.snapHasLastUkey = hasLastUkey + b.snapLastUkey = append(b.snapLastUkey[:0], lastUkey...) + b.snapLastSeq = lastSeq + b.snapIter = i + b.snapKerrCnt = b.kerrCnt + b.snapDropCnt = b.dropCnt + } + + hasLastUkey = true + lastUkey = append(lastUkey[:0], ukey...) + lastSeq = keyMaxSeq + } + + switch { + case lastSeq <= b.minSeq: + // Dropped because newer entry for same user key exist + fallthrough // (A) + case kt == keyTypeDel && seq <= b.minSeq && b.c.baseLevelForKey(lastUkey): + // For this user key: + // (1) there is no data in higher levels + // (2) data in lower levels will have larger seq numbers + // (3) data in layers that are being compacted here and have + // smaller seq numbers will be dropped in the next + // few iterations of this loop (by rule (A) above). + // Therefore this deletion marker is obsolete and can be dropped. + lastSeq = seq + b.dropCnt++ + continue + default: + lastSeq = seq + } + } else { + if b.strict { + return kerr + } + + // Don't drop corrupted keys. + hasLastUkey = false + lastUkey = lastUkey[:0] + lastSeq = keyMaxSeq + b.kerrCnt++ + } + + if err := b.appendKV(ikey, iter.Value()); err != nil { + return err + } + } + + if err := iter.Error(); err != nil { + return err + } + + // Finish last table. + if b.tw != nil && !b.tw.empty() { + return b.flush() + } + return nil +} + +func (b *tableCompactionBuilder) revert() error { + for _, at := range b.rec.addedTables { + b.s.logf("table@build revert @%d", at.num) + if err := b.s.stor.Remove(storage.FileDesc{Type: storage.TypeTable, Num: at.num}); err != nil { + return err + } + } + return nil +} + +func (db *DB) tableCompaction(c *compaction, noTrivial bool) { + defer c.release() + + rec := &sessionRecord{} + rec.addCompPtr(c.sourceLevel, c.imax) + + if !noTrivial && c.trivial() { + t := c.levels[0][0] + db.logf("table@move L%d@%d -> L%d", c.sourceLevel, t.fd.Num, c.sourceLevel+1) + rec.delTable(c.sourceLevel, t.fd.Num) + rec.addTableFile(c.sourceLevel+1, t) + db.compactionCommit("table-move", rec) + return + } + + var stats [2]cStatStaging + for i, tables := range c.levels { + for _, t := range tables { + stats[i].read += t.size + // Insert deleted tables into record + rec.delTable(c.sourceLevel+i, t.fd.Num) + } + } + sourceSize := int(stats[0].read + stats[1].read) + minSeq := db.minSeq() + db.logf("table@compaction L%d·%d -> L%d·%d S·%s Q·%d", c.sourceLevel, len(c.levels[0]), c.sourceLevel+1, len(c.levels[1]), shortenb(sourceSize), minSeq) + + b := &tableCompactionBuilder{ + db: db, + s: db.s, + c: c, + rec: rec, + stat1: &stats[1], + minSeq: minSeq, + strict: db.s.o.GetStrict(opt.StrictCompaction), + tableSize: db.s.o.GetCompactionTableSize(c.sourceLevel + 1), + } + db.compactionTransact("table@build", b) + + // Commit. 
+ stats[1].startTimer() + db.compactionCommit("table", rec) + stats[1].stopTimer() + + resultSize := int(stats[1].write) + db.logf("table@compaction committed F%s S%s Ke·%d D·%d T·%v", sint(len(rec.addedTables)-len(rec.deletedTables)), sshortenb(resultSize-sourceSize), b.kerrCnt, b.dropCnt, stats[1].duration) + + // Save compaction stats + for i := range stats { + db.compStats.addStat(c.sourceLevel+1, &stats[i]) + } + switch c.typ { + case level0Compaction: + atomic.AddUint32(&db.level0Comp, 1) + case nonLevel0Compaction: + atomic.AddUint32(&db.nonLevel0Comp, 1) + case seekCompaction: + atomic.AddUint32(&db.seekComp, 1) + } +} + +func (db *DB) tableRangeCompaction(level int, umin, umax []byte) error { + db.logf("table@compaction range L%d %q:%q", level, umin, umax) + if level >= 0 { + if c := db.s.getCompactionRange(level, umin, umax, true); c != nil { + db.tableCompaction(c, true) + } + } else { + // Retry until nothing to compact. + for { + compacted := false + + // Scan for maximum level with overlapped tables. + v := db.s.version() + m := 1 + for i := m; i < len(v.levels); i++ { + tables := v.levels[i] + if tables.overlaps(db.s.icmp, umin, umax, false) { + m = i + } + } + v.release() + + for level := 0; level < m; level++ { + if c := db.s.getCompactionRange(level, umin, umax, false); c != nil { + db.tableCompaction(c, true) + compacted = true + } + } + + if !compacted { + break + } + } + } + + return nil +} + +func (db *DB) tableAutoCompaction() { + if c := db.s.pickCompaction(); c != nil { + db.tableCompaction(c, false) + } +} + +func (db *DB) tableNeedCompaction() bool { + v := db.s.version() + defer v.release() + return v.needCompaction() +} + +// resumeWrite returns an indicator whether we should resume write operation if enough level0 files are compacted. +func (db *DB) resumeWrite() bool { + v := db.s.version() + defer v.release() + if v.tLen(0) < db.s.o.GetWriteL0PauseTrigger() { + return true + } + return false +} + +func (db *DB) pauseCompaction(ch chan<- struct{}) { + select { + case ch <- struct{}{}: + case <-db.closeC: + db.compactionExitTransact() + } +} + +type cCmd interface { + ack(err error) +} + +type cAuto struct { + // Note for table compaction, an non-empty ackC represents it's a compaction waiting command. + ackC chan<- error +} + +func (r cAuto) ack(err error) { + if r.ackC != nil { + defer func() { + recover() + }() + r.ackC <- err + } +} + +type cRange struct { + level int + min, max []byte + ackC chan<- error +} + +func (r cRange) ack(err error) { + if r.ackC != nil { + defer func() { + recover() + }() + r.ackC <- err + } +} + +// This will trigger auto compaction but will not wait for it. +func (db *DB) compTrigger(compC chan<- cCmd) { + select { + case compC <- cAuto{}: + default: + } +} + +// This will trigger auto compaction and/or wait for all compaction to be done. +func (db *DB) compTriggerWait(compC chan<- cCmd) (err error) { + ch := make(chan error) + defer close(ch) + // Send cmd. + select { + case compC <- cAuto{ch}: + case err = <-db.compErrC: + return + case <-db.closeC: + return ErrClosed + } + // Wait cmd. + select { + case err = <-ch: + case err = <-db.compErrC: + case <-db.closeC: + return ErrClosed + } + return err +} + +// Send range compaction request. +func (db *DB) compTriggerRange(compC chan<- cCmd, level int, min, max []byte) (err error) { + ch := make(chan error) + defer close(ch) + // Send cmd. 
+ select { + case compC <- cRange{level, min, max, ch}: + case err := <-db.compErrC: + return err + case <-db.closeC: + return ErrClosed + } + // Wait cmd. + select { + case err = <-ch: + case err = <-db.compErrC: + case <-db.closeC: + return ErrClosed + } + return err +} + +func (db *DB) mCompaction() { + var x cCmd + + defer func() { + if x := recover(); x != nil { + if x != errCompactionTransactExiting { + panic(x) + } + } + if x != nil { + x.ack(ErrClosed) + } + db.closeW.Done() + }() + + for { + select { + case x = <-db.mcompCmdC: + switch x.(type) { + case cAuto: + db.memCompaction() + x.ack(nil) + x = nil + default: + panic("leveldb: unknown command") + } + case <-db.closeC: + return + } + } +} + +func (db *DB) tCompaction() { + var ( + x cCmd + waitQ []cCmd + ) + + defer func() { + if x := recover(); x != nil { + if x != errCompactionTransactExiting { + panic(x) + } + } + for i := range waitQ { + waitQ[i].ack(ErrClosed) + waitQ[i] = nil + } + if x != nil { + x.ack(ErrClosed) + } + db.closeW.Done() + }() + + for { + if db.tableNeedCompaction() { + select { + case x = <-db.tcompCmdC: + case ch := <-db.tcompPauseC: + db.pauseCompaction(ch) + continue + case <-db.closeC: + return + default: + } + // Resume write operation as soon as possible. + if len(waitQ) > 0 && db.resumeWrite() { + for i := range waitQ { + waitQ[i].ack(nil) + waitQ[i] = nil + } + waitQ = waitQ[:0] + } + } else { + for i := range waitQ { + waitQ[i].ack(nil) + waitQ[i] = nil + } + waitQ = waitQ[:0] + select { + case x = <-db.tcompCmdC: + case ch := <-db.tcompPauseC: + db.pauseCompaction(ch) + continue + case <-db.closeC: + return + } + } + if x != nil { + switch cmd := x.(type) { + case cAuto: + if cmd.ackC != nil { + // Check the write pause state before caching it. + if db.resumeWrite() { + x.ack(nil) + } else { + waitQ = append(waitQ, x) + } + } + case cRange: + x.ack(db.tableRangeCompaction(cmd.level, cmd.min, cmd.max)) + default: + panic("leveldb: unknown command") + } + x = nil + } + db.tableAutoCompaction() + } +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/db_iter.go b/vendor/github.com/syndtr/goleveldb/leveldb/db_iter.go new file mode 100644 index 0000000..e6e8ca5 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/db_iter.go @@ -0,0 +1,369 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+
+package leveldb
+
+import (
+	"errors"
+	"math/rand"
+	"runtime"
+	"sync"
+	"sync/atomic"
+
+	"github.com/syndtr/goleveldb/leveldb/iterator"
+	"github.com/syndtr/goleveldb/leveldb/opt"
+	"github.com/syndtr/goleveldb/leveldb/util"
+)
+
+var (
+	errInvalidInternalKey = errors.New("leveldb: Iterator: invalid internal key")
+)
+
+type memdbReleaser struct {
+	once sync.Once
+	m    *memDB
+}
+
+func (mr *memdbReleaser) Release() {
+	mr.once.Do(func() {
+		mr.m.decref()
+	})
+}
+
+func (db *DB) newRawIterator(auxm *memDB, auxt tFiles, slice *util.Range, ro *opt.ReadOptions) iterator.Iterator {
+	strict := opt.GetStrict(db.s.o.Options, ro, opt.StrictReader)
+	em, fm := db.getMems()
+	v := db.s.version()
+
+	tableIts := v.getIterators(slice, ro)
+	n := len(tableIts) + len(auxt) + 3
+	its := make([]iterator.Iterator, 0, n)
+
+	if auxm != nil {
+		ami := auxm.NewIterator(slice)
+		ami.SetReleaser(&memdbReleaser{m: auxm})
+		its = append(its, ami)
+	}
+	for _, t := range auxt {
+		its = append(its, v.s.tops.newIterator(t, slice, ro))
+	}
+
+	emi := em.NewIterator(slice)
+	emi.SetReleaser(&memdbReleaser{m: em})
+	its = append(its, emi)
+	if fm != nil {
+		fmi := fm.NewIterator(slice)
+		fmi.SetReleaser(&memdbReleaser{m: fm})
+		its = append(its, fmi)
+	}
+	its = append(its, tableIts...)
+	mi := iterator.NewMergedIterator(its, db.s.icmp, strict)
+	mi.SetReleaser(&versionReleaser{v: v})
+	return mi
+}
+
+func (db *DB) newIterator(auxm *memDB, auxt tFiles, seq uint64, slice *util.Range, ro *opt.ReadOptions) *dbIter {
+	var islice *util.Range
+	if slice != nil {
+		islice = &util.Range{}
+		if slice.Start != nil {
+			islice.Start = makeInternalKey(nil, slice.Start, keyMaxSeq, keyTypeSeek)
+		}
+		if slice.Limit != nil {
+			islice.Limit = makeInternalKey(nil, slice.Limit, keyMaxSeq, keyTypeSeek)
+		}
+	}
+	rawIter := db.newRawIterator(auxm, auxt, islice, ro)
+	iter := &dbIter{
+		db:              db,
+		icmp:            db.s.icmp,
+		iter:            rawIter,
+		seq:             seq,
+		strict:          opt.GetStrict(db.s.o.Options, ro, opt.StrictReader),
+		disableSampling: db.s.o.GetDisableSeeksCompaction() || db.s.o.GetIteratorSamplingRate() <= 0,
+		key:             make([]byte, 0),
+		value:           make([]byte, 0),
+	}
+	if !iter.disableSampling {
+		iter.samplingGap = db.iterSamplingRate()
+	}
+	atomic.AddInt32(&db.aliveIters, 1)
+	runtime.SetFinalizer(iter, (*dbIter).Release)
+	return iter
+}
+
+func (db *DB) iterSamplingRate() int {
+	return rand.Intn(2 * db.s.o.GetIteratorSamplingRate())
+}
+
+type dir int
+
+const (
+	dirReleased dir = iota - 1
+	dirSOI
+	dirEOI
+	dirBackward
+	dirForward
+)
+
+// dbIter represents an iterator over a database session.
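+// It wraps a merged iterator over the memdbs and table files and filters
+// entries by sequence number, so only data visible at the iterator's
+// snapshot sequence is exposed.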
+type dbIter struct { + db *DB + icmp *iComparer + iter iterator.Iterator + seq uint64 + strict bool + disableSampling bool + + samplingGap int + dir dir + key []byte + value []byte + err error + releaser util.Releaser +} + +func (i *dbIter) sampleSeek() { + if i.disableSampling { + return + } + + ikey := i.iter.Key() + i.samplingGap -= len(ikey) + len(i.iter.Value()) + for i.samplingGap < 0 { + i.samplingGap += i.db.iterSamplingRate() + i.db.sampleSeek(ikey) + } +} + +func (i *dbIter) setErr(err error) { + i.err = err + i.key = nil + i.value = nil +} + +func (i *dbIter) iterErr() { + if err := i.iter.Error(); err != nil { + i.setErr(err) + } +} + +func (i *dbIter) Valid() bool { + return i.err == nil && i.dir > dirEOI +} + +func (i *dbIter) First() bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + if i.iter.First() { + i.dir = dirSOI + return i.next() + } + i.dir = dirEOI + i.iterErr() + return false +} + +func (i *dbIter) Last() bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + if i.iter.Last() { + return i.prev() + } + i.dir = dirSOI + i.iterErr() + return false +} + +func (i *dbIter) Seek(key []byte) bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + ikey := makeInternalKey(nil, key, i.seq, keyTypeSeek) + if i.iter.Seek(ikey) { + i.dir = dirSOI + return i.next() + } + i.dir = dirEOI + i.iterErr() + return false +} + +func (i *dbIter) next() bool { + for { + if ukey, seq, kt, kerr := parseInternalKey(i.iter.Key()); kerr == nil { + i.sampleSeek() + if seq <= i.seq { + switch kt { + case keyTypeDel: + // Skip deleted key. + i.key = append(i.key[:0], ukey...) + i.dir = dirForward + case keyTypeVal: + if i.dir == dirSOI || i.icmp.uCompare(ukey, i.key) > 0 { + i.key = append(i.key[:0], ukey...) + i.value = append(i.value[:0], i.iter.Value()...) + i.dir = dirForward + return true + } + } + } + } else if i.strict { + i.setErr(kerr) + break + } + if !i.iter.Next() { + i.dir = dirEOI + i.iterErr() + break + } + } + return false +} + +func (i *dbIter) Next() bool { + if i.dir == dirEOI || i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + if !i.iter.Next() || (i.dir == dirBackward && !i.iter.Next()) { + i.dir = dirEOI + i.iterErr() + return false + } + return i.next() +} + +func (i *dbIter) prev() bool { + i.dir = dirBackward + del := true + if i.iter.Valid() { + for { + if ukey, seq, kt, kerr := parseInternalKey(i.iter.Key()); kerr == nil { + i.sampleSeek() + if seq <= i.seq { + if !del && i.icmp.uCompare(ukey, i.key) < 0 { + return true + } + del = (kt == keyTypeDel) + if !del { + i.key = append(i.key[:0], ukey...) + i.value = append(i.value[:0], i.iter.Value()...) 
+ } + } + } else if i.strict { + i.setErr(kerr) + return false + } + if !i.iter.Prev() { + break + } + } + } + if del { + i.dir = dirSOI + i.iterErr() + return false + } + return true +} + +func (i *dbIter) Prev() bool { + if i.dir == dirSOI || i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + switch i.dir { + case dirEOI: + return i.Last() + case dirForward: + for i.iter.Prev() { + if ukey, _, _, kerr := parseInternalKey(i.iter.Key()); kerr == nil { + i.sampleSeek() + if i.icmp.uCompare(ukey, i.key) < 0 { + goto cont + } + } else if i.strict { + i.setErr(kerr) + return false + } + } + i.dir = dirSOI + i.iterErr() + return false + } + +cont: + return i.prev() +} + +func (i *dbIter) Key() []byte { + if i.err != nil || i.dir <= dirEOI { + return nil + } + return i.key +} + +func (i *dbIter) Value() []byte { + if i.err != nil || i.dir <= dirEOI { + return nil + } + return i.value +} + +func (i *dbIter) Release() { + if i.dir != dirReleased { + // Clear the finalizer. + runtime.SetFinalizer(i, nil) + + if i.releaser != nil { + i.releaser.Release() + i.releaser = nil + } + + i.dir = dirReleased + i.key = nil + i.value = nil + i.iter.Release() + i.iter = nil + atomic.AddInt32(&i.db.aliveIters, -1) + i.db = nil + } +} + +func (i *dbIter) SetReleaser(releaser util.Releaser) { + if i.dir == dirReleased { + panic(util.ErrReleased) + } + if i.releaser != nil && releaser != nil { + panic(util.ErrHasReleaser) + } + i.releaser = releaser +} + +func (i *dbIter) Error() error { + return i.err +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/db_snapshot.go b/vendor/github.com/syndtr/goleveldb/leveldb/db_snapshot.go new file mode 100644 index 0000000..c2ad70c --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/db_snapshot.go @@ -0,0 +1,187 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "container/list" + "fmt" + "runtime" + "sync" + "sync/atomic" + + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/util" +) + +type snapshotElement struct { + seq uint64 + ref int + e *list.Element +} + +// Acquires a snapshot, based on latest sequence. +func (db *DB) acquireSnapshot() *snapshotElement { + db.snapsMu.Lock() + defer db.snapsMu.Unlock() + + seq := db.getSeq() + + if e := db.snapsList.Back(); e != nil { + se := e.Value.(*snapshotElement) + if se.seq == seq { + se.ref++ + return se + } else if seq < se.seq { + panic("leveldb: sequence number is not increasing") + } + } + se := &snapshotElement{seq: seq, ref: 1} + se.e = db.snapsList.PushBack(se) + return se +} + +// Releases given snapshot element. +func (db *DB) releaseSnapshot(se *snapshotElement) { + db.snapsMu.Lock() + defer db.snapsMu.Unlock() + + se.ref-- + if se.ref == 0 { + db.snapsList.Remove(se.e) + se.e = nil + } else if se.ref < 0 { + panic("leveldb: Snapshot: negative element reference") + } +} + +// Gets minimum sequence that not being snapshotted. +func (db *DB) minSeq() uint64 { + db.snapsMu.Lock() + defer db.snapsMu.Unlock() + + if e := db.snapsList.Front(); e != nil { + return e.Value.(*snapshotElement).seq + } + + return db.getSeq() +} + +// Snapshot is a DB snapshot. +type Snapshot struct { + db *DB + elem *snapshotElement + mu sync.RWMutex + released bool +} + +// Creates new snapshot object. 
+func (db *DB) newSnapshot() *Snapshot {
+	snap := &Snapshot{
+		db:   db,
+		elem: db.acquireSnapshot(),
+	}
+	atomic.AddInt32(&db.aliveSnaps, 1)
+	runtime.SetFinalizer(snap, (*Snapshot).Release)
+	return snap
+}
+
+func (snap *Snapshot) String() string {
+	return fmt.Sprintf("leveldb.Snapshot{%d}", snap.elem.seq)
+}
+
+// Get gets the value for the given key. It returns ErrNotFound if
+// the DB does not contain the key.
+//
+// The caller should not modify the contents of the returned slice, but
+// it is safe to modify the contents of the argument after Get returns.
+func (snap *Snapshot) Get(key []byte, ro *opt.ReadOptions) (value []byte, err error) {
+	err = snap.db.ok()
+	if err != nil {
+		return
+	}
+	snap.mu.RLock()
+	defer snap.mu.RUnlock()
+	if snap.released {
+		err = ErrSnapshotReleased
+		return
+	}
+	return snap.db.get(nil, nil, key, snap.elem.seq, ro)
+}
+
+// Has returns true if the DB does contain the given key.
+//
+// It is safe to modify the contents of the argument after Has returns.
+func (snap *Snapshot) Has(key []byte, ro *opt.ReadOptions) (ret bool, err error) {
+	err = snap.db.ok()
+	if err != nil {
+		return
+	}
+	snap.mu.RLock()
+	defer snap.mu.RUnlock()
+	if snap.released {
+		err = ErrSnapshotReleased
+		return
+	}
+	return snap.db.has(nil, nil, key, snap.elem.seq, ro)
+}
+
+// NewIterator returns an iterator for the snapshot of the underlying DB.
+// The returned iterator is not safe for concurrent use, but it is safe to use
+// multiple iterators concurrently, with each in a dedicated goroutine.
+// It is also safe to use an iterator concurrently with modifying its
+// underlying DB. The resultant key/value pairs are guaranteed to be
+// consistent.
+//
+// Slice allows slicing the iterator to only contain keys in the given
+// range. A nil Range.Start is treated as a key before all keys in the
+// DB, and a nil Range.Limit is treated as a key after all keys in
+// the DB.
+//
+// WARNING: The contents of any slice returned by the iterator (e.g. a slice
+// returned by the Iterator.Key() or Iterator.Value() methods) should not be
+// modified unless noted otherwise.
+//
+// The iterator must be released after use, by calling Release method.
+// Releasing the snapshot doesn't release the iterator too; the
+// iterator remains valid until released.
+//
+// Also read Iterator documentation of the leveldb/iterator package.
+func (snap *Snapshot) NewIterator(slice *util.Range, ro *opt.ReadOptions) iterator.Iterator {
+	if err := snap.db.ok(); err != nil {
+		return iterator.NewEmptyIterator(err)
+	}
+	snap.mu.Lock()
+	defer snap.mu.Unlock()
+	if snap.released {
+		return iterator.NewEmptyIterator(ErrSnapshotReleased)
+	}
+	// Since the iterator already holds a version ref, it doesn't need to
+	// hold a snapshot ref.
+	return snap.db.newIterator(nil, nil, snap.elem.seq, slice, ro)
+}
+
+// Release releases the snapshot. This will not release any returned
+// iterators; they remain valid until released or until the underlying DB
+// is closed.
+//
+// Other methods should not be called after the snapshot has been released.
+func (snap *Snapshot) Release() {
+	snap.mu.Lock()
+	defer snap.mu.Unlock()
+
+	if !snap.released {
+		// Clear the finalizer.
+ runtime.SetFinalizer(snap, nil)
+
+ snap.released = true
+ snap.db.releaseSnapshot(snap.elem)
+ atomic.AddInt32(&snap.db.aliveSnaps, -1)
+ snap.db = nil
+ snap.elem = nil
+ }
+}
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/db_state.go b/vendor/github.com/syndtr/goleveldb/leveldb/db_state.go
new file mode 100644
index 0000000..65e1c54
--- /dev/null
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/db_state.go
@@ -0,0 +1,239 @@
+// Copyright (c) 2013, Suryandaru Triandana <syndtr@gmail.com>
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package leveldb
+
+import (
+ "errors"
+ "sync/atomic"
+ "time"
+
+ "github.com/syndtr/goleveldb/leveldb/journal"
+ "github.com/syndtr/goleveldb/leveldb/memdb"
+ "github.com/syndtr/goleveldb/leveldb/storage"
+)
+
+var (
+ errHasFrozenMem = errors.New("has frozen mem")
+)
+
+type memDB struct {
+ db *DB
+ *memdb.DB
+ ref int32
+}
+
+func (m *memDB) getref() int32 {
+ return atomic.LoadInt32(&m.ref)
+}
+
+func (m *memDB) incref() {
+ atomic.AddInt32(&m.ref, 1)
+}
+
+func (m *memDB) decref() {
+ if ref := atomic.AddInt32(&m.ref, -1); ref == 0 {
+ // Only put back memdb with std capacity.
+ if m.Capacity() == m.db.s.o.GetWriteBuffer() {
+ m.Reset()
+ m.db.mpoolPut(m.DB)
+ }
+ m.db = nil
+ m.DB = nil
+ } else if ref < 0 {
+ panic("negative memdb ref")
+ }
+}
+
+// Get latest sequence number.
+func (db *DB) getSeq() uint64 {
+ return atomic.LoadUint64(&db.seq)
+}
+
+// Atomically adds delta to seq.
+func (db *DB) addSeq(delta uint64) {
+ atomic.AddUint64(&db.seq, delta)
+}
+
+func (db *DB) setSeq(seq uint64) {
+ atomic.StoreUint64(&db.seq, seq)
+}
+
+func (db *DB) sampleSeek(ikey internalKey) {
+ v := db.s.version()
+ if v.sampleSeek(ikey) {
+ // Trigger table compaction.
+ db.compTrigger(db.tcompCmdC)
+ }
+ v.release()
+}
+
+func (db *DB) mpoolPut(mem *memdb.DB) {
+ if !db.isClosed() {
+ select {
+ case db.memPool <- mem:
+ default:
+ }
+ }
+}
+
+func (db *DB) mpoolGet(n int) *memDB {
+ var mdb *memdb.DB
+ select {
+ case mdb = <-db.memPool:
+ default:
+ }
+ if mdb == nil || mdb.Capacity() < n {
+ mdb = memdb.New(db.s.icmp, maxInt(db.s.o.GetWriteBuffer(), n))
+ }
+ return &memDB{
+ db: db,
+ DB: mdb,
+ }
+}
+
+func (db *DB) mpoolDrain() {
+ ticker := time.NewTicker(30 * time.Second)
+ for {
+ select {
+ case <-ticker.C:
+ select {
+ case <-db.memPool:
+ default:
+ }
+ case <-db.closeC:
+ ticker.Stop()
+ // Make sure the pool is drained.
+ select {
+ case <-db.memPool:
+ case <-time.After(time.Second):
+ }
+ close(db.memPool)
+ return
+ }
+ }
+}
+
+// Create a new memdb and freeze the old one; needs external synchronization.
+// newMem is only called synchronously by the writer.
+func (db *DB) newMem(n int) (mem *memDB, err error) {
+ fd := storage.FileDesc{Type: storage.TypeJournal, Num: db.s.allocFileNum()}
+ w, err := db.s.stor.Create(fd)
+ if err != nil {
+ db.s.reuseFileNum(fd.Num)
+ return
+ }
+
+ db.memMu.Lock()
+ defer db.memMu.Unlock()
+
+ if db.frozenMem != nil {
+ return nil, errHasFrozenMem
+ }
+
+ if db.journal == nil {
+ db.journal = journal.NewWriter(w)
+ } else {
+ db.journal.Reset(w)
+ db.journalWriter.Close()
+ db.frozenJournalFd = db.journalFd
+ }
+ db.journalWriter = w
+ db.journalFd = fd
+ db.frozenMem = db.mem
+ mem = db.mpoolGet(n)
+ mem.incref() // for self
+ mem.incref() // for caller
+ db.mem = mem
+ // The seq is only incremented by the writer, and whoever calls newMem
+ // must hold the write lock, so no additional synchronization is needed here.
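+ // (Record the sequence number at the moment of rotation; the frozen
+ // memdb covers every write up to this seq.)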
+ db.frozenSeq = db.seq
+ return
+}
+
+// Get all memdbs.
+func (db *DB) getMems() (e, f *memDB) {
+ db.memMu.RLock()
+ defer db.memMu.RUnlock()
+ if db.mem != nil {
+ db.mem.incref()
+ } else if !db.isClosed() {
+ panic("nil effective mem")
+ }
+ if db.frozenMem != nil {
+ db.frozenMem.incref()
+ }
+ return db.mem, db.frozenMem
+}
+
+// Get effective memdb.
+func (db *DB) getEffectiveMem() *memDB {
+ db.memMu.RLock()
+ defer db.memMu.RUnlock()
+ if db.mem != nil {
+ db.mem.incref()
+ } else if !db.isClosed() {
+ panic("nil effective mem")
+ }
+ return db.mem
+}
+
+// Check whether we have a frozen memdb.
+func (db *DB) hasFrozenMem() bool {
+ db.memMu.RLock()
+ defer db.memMu.RUnlock()
+ return db.frozenMem != nil
+}
+
+// Get frozen memdb.
+func (db *DB) getFrozenMem() *memDB {
+ db.memMu.RLock()
+ defer db.memMu.RUnlock()
+ if db.frozenMem != nil {
+ db.frozenMem.incref()
+ }
+ return db.frozenMem
+}
+
+// Drop frozen memdb; assumes that the frozen memdb isn't nil.
+func (db *DB) dropFrozenMem() {
+ db.memMu.Lock()
+ if err := db.s.stor.Remove(db.frozenJournalFd); err != nil {
+ db.logf("journal@remove removing @%d %q", db.frozenJournalFd.Num, err)
+ } else {
+ db.logf("journal@remove removed @%d", db.frozenJournalFd.Num)
+ }
+ db.frozenJournalFd = storage.FileDesc{}
+ db.frozenMem.decref()
+ db.frozenMem = nil
+ db.memMu.Unlock()
+}
+
+// Clear mems ptr; used by DB.Close().
+func (db *DB) clearMems() {
+ db.memMu.Lock()
+ db.mem = nil
+ db.frozenMem = nil
+ db.memMu.Unlock()
+}
+
+// Set closed flag; return true if not already closed.
+func (db *DB) setClosed() bool {
+ return atomic.CompareAndSwapUint32(&db.closed, 0, 1)
+}
+
+// Check whether the DB was closed.
+func (db *DB) isClosed() bool {
+ return atomic.LoadUint32(&db.closed) != 0
+}
+
+// Check read ok status.
+func (db *DB) ok() error {
+ if db.isClosed() {
+ return ErrClosed
+ }
+ return nil
+}
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/db_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/db_test.go
new file mode 100644
index 0000000..5749b4d
--- /dev/null
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/db_test.go
@@ -0,0 +1,2926 @@
+// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+ +package leveldb + +import ( + "bytes" + "container/list" + crand "crypto/rand" + "encoding/binary" + "fmt" + "math/rand" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + "unsafe" + + "github.com/onsi/gomega" + + "github.com/syndtr/goleveldb/leveldb/comparer" + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/filter" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" + "github.com/syndtr/goleveldb/leveldb/testutil" + "github.com/syndtr/goleveldb/leveldb/util" +) + +func tkey(i int) []byte { + return []byte(fmt.Sprintf("%016d", i)) +} + +func tval(seed, n int) []byte { + r := rand.New(rand.NewSource(int64(seed))) + return randomString(r, n) +} + +func testingLogger(t *testing.T) func(log string) { + return func(log string) { + t.Log(log) + } +} + +func testingPreserveOnFailed(t *testing.T) func() (preserve bool, err error) { + return func() (preserve bool, err error) { + preserve = t.Failed() + return + } +} + +type dbHarness struct { + t *testing.T + + stor *testutil.Storage + db *DB + o *opt.Options + ro *opt.ReadOptions + wo *opt.WriteOptions +} + +func newDbHarnessWopt(t *testing.T, o *opt.Options) *dbHarness { + h := new(dbHarness) + h.init(t, o) + return h +} + +func newDbHarness(t *testing.T) *dbHarness { + return newDbHarnessWopt(t, &opt.Options{DisableLargeBatchTransaction: true}) +} + +func (h *dbHarness) init(t *testing.T, o *opt.Options) { + gomega.RegisterTestingT(t) + h.t = t + h.stor = testutil.NewStorage() + h.stor.OnLog(testingLogger(t)) + h.stor.OnClose(testingPreserveOnFailed(t)) + h.o = o + h.ro = nil + h.wo = nil + + if err := h.openDB0(); err != nil { + // So that it will come after fatal message. 
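+ // (t.Fatal exits through runtime.Goexit, which still runs deferred
+ // calls, so the deferred Close is logged after the fatal message.)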
+ defer h.stor.Close() + h.t.Fatal("Open (init): got error: ", err) + } +} + +func (h *dbHarness) openDB0() (err error) { + h.t.Log("opening DB") + h.db, err = Open(h.stor, h.o) + return +} + +func (h *dbHarness) openDB() { + if err := h.openDB0(); err != nil { + h.t.Fatal("Open: got error: ", err) + } +} + +func (h *dbHarness) closeDB0() error { + h.t.Log("closing DB") + return h.db.Close() +} + +func (h *dbHarness) closeDB() { + if h.db != nil { + if err := h.closeDB0(); err != nil { + h.t.Error("Close: got error: ", err) + } + } + h.stor.CloseCheck() + runtime.GC() +} + +func (h *dbHarness) reopenDB() { + if h.db != nil { + h.closeDB() + } + h.openDB() +} + +func (h *dbHarness) close() { + if h.db != nil { + h.closeDB0() + h.db = nil + } + h.stor.Close() + h.stor = nil + runtime.GC() +} + +func (h *dbHarness) openAssert(want bool) { + db, err := Open(h.stor, h.o) + if err != nil { + if want { + h.t.Error("Open: assert: got error: ", err) + } else { + h.t.Log("Open: assert: got error (expected): ", err) + } + } else { + if !want { + h.t.Error("Open: assert: expect error") + } + db.Close() + } +} + +func (h *dbHarness) write(batch *Batch) { + if err := h.db.Write(batch, h.wo); err != nil { + h.t.Error("Write: got error: ", err) + } +} + +func (h *dbHarness) put(key, value string) { + if err := h.db.Put([]byte(key), []byte(value), h.wo); err != nil { + h.t.Error("Put: got error: ", err) + } +} + +func (h *dbHarness) putMulti(n int, low, hi string) { + for i := 0; i < n; i++ { + h.put(low, "begin") + h.put(hi, "end") + h.compactMem() + } +} + +func (h *dbHarness) maxNextLevelOverlappingBytes(want int64) { + t := h.t + db := h.db + + var ( + maxOverlaps int64 + maxLevel int + ) + v := db.s.version() + if len(v.levels) > 2 { + for i, tt := range v.levels[1 : len(v.levels)-1] { + level := i + 1 + next := v.levels[level+1] + for _, t := range tt { + r := next.getOverlaps(nil, db.s.icmp, t.imin.ukey(), t.imax.ukey(), false) + sum := r.size() + if sum > maxOverlaps { + maxOverlaps = sum + maxLevel = level + } + } + } + } + v.release() + + if maxOverlaps > want { + t.Errorf("next level most overlapping bytes is more than %d, got=%d level=%d", want, maxOverlaps, maxLevel) + } else { + t.Logf("next level most overlapping bytes is %d, level=%d want=%d", maxOverlaps, maxLevel, want) + } +} + +func (h *dbHarness) delete(key string) { + t := h.t + db := h.db + + err := db.Delete([]byte(key), h.wo) + if err != nil { + t.Error("Delete: got error: ", err) + } +} + +func (h *dbHarness) assertNumKeys(want int) { + iter := h.db.NewIterator(nil, h.ro) + defer iter.Release() + got := 0 + for iter.Next() { + got++ + } + if err := iter.Error(); err != nil { + h.t.Error("assertNumKeys: ", err) + } + if want != got { + h.t.Errorf("assertNumKeys: want=%d got=%d", want, got) + } +} + +func (h *dbHarness) getr(db Reader, key string, expectFound bool) (found bool, v []byte) { + t := h.t + v, err := db.Get([]byte(key), h.ro) + switch err { + case ErrNotFound: + if expectFound { + t.Errorf("Get: key '%s' not found, want found", key) + } + case nil: + found = true + if !expectFound { + t.Errorf("Get: key '%s' found, want not found", key) + } + default: + t.Error("Get: got error: ", err) + } + return +} + +func (h *dbHarness) get(key string, expectFound bool) (found bool, v []byte) { + return h.getr(h.db, key, expectFound) +} + +func (h *dbHarness) getValr(db Reader, key, value string) { + t := h.t + found, r := h.getr(db, key, true) + if !found { + return + } + rval := string(r) + if rval != value { + t.Errorf("Get: 
invalid value, got '%s', want '%s'", rval, value) + } +} + +func (h *dbHarness) getVal(key, value string) { + h.getValr(h.db, key, value) +} + +func (h *dbHarness) allEntriesFor(key, want string) { + t := h.t + db := h.db + s := db.s + + ikey := makeInternalKey(nil, []byte(key), keyMaxSeq, keyTypeVal) + iter := db.newRawIterator(nil, nil, nil, nil) + if !iter.Seek(ikey) && iter.Error() != nil { + t.Error("AllEntries: error during seek, err: ", iter.Error()) + return + } + res := "[ " + first := true + for iter.Valid() { + if ukey, _, kt, kerr := parseInternalKey(iter.Key()); kerr == nil { + if s.icmp.uCompare(ikey.ukey(), ukey) != 0 { + break + } + if !first { + res += ", " + } + first = false + switch kt { + case keyTypeVal: + res += string(iter.Value()) + case keyTypeDel: + res += "DEL" + } + } else { + if !first { + res += ", " + } + first = false + res += "CORRUPTED" + } + iter.Next() + } + if !first { + res += " " + } + res += "]" + if res != want { + t.Errorf("AllEntries: assert failed for key %q, got=%q want=%q", key, res, want) + } +} + +// Return a string that contains all key,value pairs in order, +// formatted like "(k1->v1)(k2->v2)". +func (h *dbHarness) getKeyVal(want string) { + t := h.t + db := h.db + + s, err := db.GetSnapshot() + if err != nil { + t.Fatal("GetSnapshot: got error: ", err) + } + res := "" + iter := s.NewIterator(nil, nil) + for iter.Next() { + res += fmt.Sprintf("(%s->%s)", string(iter.Key()), string(iter.Value())) + } + iter.Release() + + if res != want { + t.Errorf("GetKeyVal: invalid key/value pair, got=%q want=%q", res, want) + } + s.Release() +} + +func (h *dbHarness) waitCompaction() { + t := h.t + db := h.db + if err := db.compTriggerWait(db.tcompCmdC); err != nil { + t.Error("compaction error: ", err) + } +} + +func (h *dbHarness) waitMemCompaction() { + t := h.t + db := h.db + + if err := db.compTriggerWait(db.mcompCmdC); err != nil { + t.Error("compaction error: ", err) + } +} + +func (h *dbHarness) compactMem() { + t := h.t + db := h.db + + t.Log("starting memdb compaction") + + db.writeLockC <- struct{}{} + defer func() { + <-db.writeLockC + }() + + if _, err := db.rotateMem(0, true); err != nil { + t.Error("compaction error: ", err) + } + + if h.totalTables() == 0 { + t.Error("zero tables after mem compaction") + } + + t.Log("memdb compaction done") +} + +func (h *dbHarness) compactRangeAtErr(level int, min, max string, wanterr bool) { + t := h.t + db := h.db + + var _min, _max []byte + if min != "" { + _min = []byte(min) + } + if max != "" { + _max = []byte(max) + } + + t.Logf("starting table range compaction: level=%d, min=%q, max=%q", level, min, max) + + if err := db.compTriggerRange(db.tcompCmdC, level, _min, _max); err != nil { + if wanterr { + t.Log("CompactRangeAt: got error (expected): ", err) + } else { + t.Error("CompactRangeAt: got error: ", err) + } + } else if wanterr { + t.Error("CompactRangeAt: expect error") + } + + t.Log("table range compaction done") +} + +func (h *dbHarness) compactRangeAt(level int, min, max string) { + h.compactRangeAtErr(level, min, max, false) +} + +func (h *dbHarness) compactRange(min, max string) { + t := h.t + db := h.db + + t.Logf("starting DB range compaction: min=%q, max=%q", min, max) + + var r util.Range + if min != "" { + r.Start = []byte(min) + } + if max != "" { + r.Limit = []byte(max) + } + if err := db.CompactRange(r); err != nil { + t.Error("CompactRange: got error: ", err) + } + + t.Log("DB range compaction done") +} + +func (h *dbHarness) sizeOf(start, limit string) int64 { + sz, err := 
h.db.SizeOf([]util.Range{ + {Start: []byte(start), Limit: []byte(limit)}, + }) + if err != nil { + h.t.Error("SizeOf: got error: ", err) + } + return sz.Sum() +} + +func (h *dbHarness) sizeAssert(start, limit string, low, hi int64) { + sz := h.sizeOf(start, limit) + if sz < low || sz > hi { + h.t.Errorf("sizeOf %q to %q not in range, want %d - %d, got %d", + shorten(start), shorten(limit), low, hi, sz) + } +} + +func (h *dbHarness) getSnapshot() (s *Snapshot) { + s, err := h.db.GetSnapshot() + if err != nil { + h.t.Fatal("GetSnapshot: got error: ", err) + } + return +} + +func (h *dbHarness) getTablesPerLevel() string { + res := "" + nz := 0 + v := h.db.s.version() + for level, tables := range v.levels { + if level > 0 { + res += "," + } + res += fmt.Sprint(len(tables)) + if len(tables) > 0 { + nz = len(res) + } + } + v.release() + return res[:nz] +} + +func (h *dbHarness) tablesPerLevel(want string) { + res := h.getTablesPerLevel() + if res != want { + h.t.Errorf("invalid tables len, want=%s, got=%s", want, res) + } +} + +func (h *dbHarness) totalTables() (n int) { + v := h.db.s.version() + for _, tables := range v.levels { + n += len(tables) + } + v.release() + return +} + +type keyValue interface { + Key() []byte + Value() []byte +} + +func testKeyVal(t *testing.T, kv keyValue, want string) { + res := string(kv.Key()) + "->" + string(kv.Value()) + if res != want { + t.Errorf("invalid key/value, want=%q, got=%q", want, res) + } +} + +func numKey(num int) string { + return fmt.Sprintf("key%06d", num) +} + +var testingBloomFilter = filter.NewBloomFilter(10) + +func truno(t *testing.T, o *opt.Options, f func(h *dbHarness)) { + for i := 0; i < 4; i++ { + func() { + switch i { + case 0: + case 1: + if o == nil { + o = &opt.Options{ + DisableLargeBatchTransaction: true, + Filter: testingBloomFilter, + } + } else { + old := o + o = &opt.Options{} + *o = *old + o.Filter = testingBloomFilter + } + case 2: + if o == nil { + o = &opt.Options{ + DisableLargeBatchTransaction: true, + Compression: opt.NoCompression, + } + } else { + old := o + o = &opt.Options{} + *o = *old + o.Compression = opt.NoCompression + } + } + h := newDbHarnessWopt(t, o) + defer h.close() + switch i { + case 3: + h.reopenDB() + } + f(h) + }() + } +} + +func trun(t *testing.T, f func(h *dbHarness)) { + truno(t, nil, f) +} + +func testAligned(t *testing.T, name string, offset uintptr) { + if offset%8 != 0 { + t.Errorf("field %s offset is not 64-bit aligned", name) + } +} + +func Test_FieldsAligned(t *testing.T) { + p1 := new(DB) + testAligned(t, "DB.seq", unsafe.Offsetof(p1.seq)) + p2 := new(session) + testAligned(t, "session.stNextFileNum", unsafe.Offsetof(p2.stNextFileNum)) + testAligned(t, "session.stJournalNum", unsafe.Offsetof(p2.stJournalNum)) + testAligned(t, "session.stPrevJournalNum", unsafe.Offsetof(p2.stPrevJournalNum)) + testAligned(t, "session.stSeqNum", unsafe.Offsetof(p2.stSeqNum)) +} + +func TestDB_Locking(t *testing.T) { + h := newDbHarness(t) + defer h.stor.Close() + h.openAssert(false) + h.closeDB() + h.openAssert(true) +} + +func TestDB_Empty(t *testing.T) { + trun(t, func(h *dbHarness) { + h.get("foo", false) + + h.reopenDB() + h.get("foo", false) + }) +} + +func TestDB_ReadWrite(t *testing.T) { + trun(t, func(h *dbHarness) { + h.put("foo", "v1") + h.getVal("foo", "v1") + h.put("bar", "v2") + h.put("foo", "v3") + h.getVal("foo", "v3") + h.getVal("bar", "v2") + + h.reopenDB() + h.getVal("foo", "v3") + h.getVal("bar", "v2") + }) +} + +func TestDB_PutDeleteGet(t *testing.T) { + trun(t, func(h *dbHarness) { + 
h.put("foo", "v1") + h.getVal("foo", "v1") + h.put("foo", "v2") + h.getVal("foo", "v2") + h.delete("foo") + h.get("foo", false) + + h.reopenDB() + h.get("foo", false) + }) +} + +func TestDB_EmptyBatch(t *testing.T) { + h := newDbHarness(t) + defer h.close() + + h.get("foo", false) + err := h.db.Write(new(Batch), h.wo) + if err != nil { + t.Error("writing empty batch yield error: ", err) + } + h.get("foo", false) +} + +func TestDB_GetFromFrozen(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + DisableLargeBatchTransaction: true, + WriteBuffer: 100100, + }) + defer h.close() + + h.put("foo", "v1") + h.getVal("foo", "v1") + + h.stor.Stall(testutil.ModeSync, storage.TypeTable) // Block sync calls + h.put("k1", strings.Repeat("x", 100000)) // Fill memtable + h.put("k2", strings.Repeat("y", 100000)) // Trigger compaction + for i := 0; h.db.getFrozenMem() == nil && i < 100; i++ { + time.Sleep(10 * time.Microsecond) + } + if h.db.getFrozenMem() == nil { + h.stor.Release(testutil.ModeSync, storage.TypeTable) + t.Fatal("No frozen mem") + } + h.getVal("foo", "v1") + h.stor.Release(testutil.ModeSync, storage.TypeTable) // Release sync calls + + h.reopenDB() + h.getVal("foo", "v1") + h.get("k1", true) + h.get("k2", true) +} + +func TestDB_GetFromTable(t *testing.T) { + trun(t, func(h *dbHarness) { + h.put("foo", "v1") + h.compactMem() + h.getVal("foo", "v1") + }) +} + +func TestDB_GetSnapshot(t *testing.T) { + trun(t, func(h *dbHarness) { + bar := strings.Repeat("b", 200) + h.put("foo", "v1") + h.put(bar, "v1") + + snap, err := h.db.GetSnapshot() + if err != nil { + t.Fatal("GetSnapshot: got error: ", err) + } + + h.put("foo", "v2") + h.put(bar, "v2") + + h.getVal("foo", "v2") + h.getVal(bar, "v2") + h.getValr(snap, "foo", "v1") + h.getValr(snap, bar, "v1") + + h.compactMem() + + h.getVal("foo", "v2") + h.getVal(bar, "v2") + h.getValr(snap, "foo", "v1") + h.getValr(snap, bar, "v1") + + snap.Release() + + h.reopenDB() + h.getVal("foo", "v2") + h.getVal(bar, "v2") + }) +} + +func TestDB_GetLevel0Ordering(t *testing.T) { + trun(t, func(h *dbHarness) { + h.db.memdbMaxLevel = 2 + + for i := 0; i < 4; i++ { + h.put("bar", fmt.Sprintf("b%d", i)) + h.put("foo", fmt.Sprintf("v%d", i)) + h.compactMem() + } + h.getVal("foo", "v3") + h.getVal("bar", "b3") + + v := h.db.s.version() + t0len := v.tLen(0) + v.release() + if t0len < 2 { + t.Errorf("level-0 tables is less than 2, got %d", t0len) + } + + h.reopenDB() + h.getVal("foo", "v3") + h.getVal("bar", "b3") + }) +} + +func TestDB_GetOrderedByLevels(t *testing.T) { + trun(t, func(h *dbHarness) { + h.put("foo", "v1") + h.compactMem() + h.compactRange("a", "z") + h.getVal("foo", "v1") + h.put("foo", "v2") + h.compactMem() + h.getVal("foo", "v2") + }) +} + +func TestDB_GetPicksCorrectFile(t *testing.T) { + trun(t, func(h *dbHarness) { + // Arrange to have multiple files in a non-level-0 level. 
+ h.put("a", "va") + h.compactMem() + h.compactRange("a", "b") + h.put("x", "vx") + h.compactMem() + h.compactRange("x", "y") + h.put("f", "vf") + h.compactMem() + h.compactRange("f", "g") + + h.getVal("a", "va") + h.getVal("f", "vf") + h.getVal("x", "vx") + + h.compactRange("", "") + h.getVal("a", "va") + h.getVal("f", "vf") + h.getVal("x", "vx") + }) +} + +func TestDB_GetEncountersEmptyLevel(t *testing.T) { + trun(t, func(h *dbHarness) { + h.db.memdbMaxLevel = 2 + + // Arrange for the following to happen: + // * sstable A in level 0 + // * nothing in level 1 + // * sstable B in level 2 + // Then do enough Get() calls to arrange for an automatic compaction + // of sstable A. A bug would cause the compaction to be marked as + // occurring at level 1 (instead of the correct level 0). + + // Step 1: First place sstables in levels 0 and 2 + for i := 0; ; i++ { + if i >= 100 { + t.Fatal("could not fill levels-0 and level-2") + } + v := h.db.s.version() + if v.tLen(0) > 0 && v.tLen(2) > 0 { + v.release() + break + } + v.release() + h.put("a", "begin") + h.put("z", "end") + h.compactMem() + + h.getVal("a", "begin") + h.getVal("z", "end") + } + + // Step 2: clear level 1 if necessary. + h.compactRangeAt(1, "", "") + h.tablesPerLevel("1,0,1") + + h.getVal("a", "begin") + h.getVal("z", "end") + + // Step 3: read a bunch of times + for i := 0; i < 200; i++ { + h.get("missing", false) + } + + // Step 4: Wait for compaction to finish + h.waitCompaction() + + v := h.db.s.version() + if v.tLen(0) > 0 { + t.Errorf("level-0 tables more than 0, got %d", v.tLen(0)) + } + v.release() + + h.getVal("a", "begin") + h.getVal("z", "end") + }) +} + +func TestDB_IterMultiWithDelete(t *testing.T) { + trun(t, func(h *dbHarness) { + h.put("a", "va") + h.put("b", "vb") + h.put("c", "vc") + h.delete("b") + h.get("b", false) + + iter := h.db.NewIterator(nil, nil) + iter.Seek([]byte("c")) + testKeyVal(t, iter, "c->vc") + iter.Prev() + testKeyVal(t, iter, "a->va") + iter.Release() + + h.compactMem() + + iter = h.db.NewIterator(nil, nil) + iter.Seek([]byte("c")) + testKeyVal(t, iter, "c->vc") + iter.Prev() + testKeyVal(t, iter, "a->va") + iter.Release() + }) +} + +func TestDB_IteratorPinsRef(t *testing.T) { + h := newDbHarness(t) + defer h.close() + + h.put("foo", "hello") + + // Get iterator that will yield the current contents of the DB. 
+ iter := h.db.NewIterator(nil, nil)
+
+ // Write to force compactions
+ h.put("foo", "newvalue1")
+ for i := 0; i < 100; i++ {
+ h.put(numKey(i), strings.Repeat(fmt.Sprintf("v%09d", i), 100000/10))
+ }
+ h.put("foo", "newvalue2")
+
+ iter.First()
+ testKeyVal(t, iter, "foo->hello")
+ if iter.Next() {
+ t.Errorf("expect eof")
+ }
+ iter.Release()
+}
+
+func TestDB_Recover(t *testing.T) {
+ trun(t, func(h *dbHarness) {
+ h.put("foo", "v1")
+ h.put("baz", "v5")
+
+ h.reopenDB()
+ h.getVal("foo", "v1")
+
+ h.getVal("foo", "v1")
+ h.getVal("baz", "v5")
+ h.put("bar", "v2")
+ h.put("foo", "v3")
+
+ h.reopenDB()
+ h.getVal("foo", "v3")
+ h.put("foo", "v4")
+ h.getVal("foo", "v4")
+ h.getVal("bar", "v2")
+ h.getVal("baz", "v5")
+ })
+}
+
+func TestDB_RecoverWithEmptyJournal(t *testing.T) {
+ trun(t, func(h *dbHarness) {
+ h.put("foo", "v1")
+ h.put("foo", "v2")
+
+ h.reopenDB()
+ h.reopenDB()
+ h.put("foo", "v3")
+
+ h.reopenDB()
+ h.getVal("foo", "v3")
+ })
+}
+
+func TestDB_RecoverDuringMemtableCompaction(t *testing.T) {
+ truno(t, &opt.Options{DisableLargeBatchTransaction: true, WriteBuffer: 1000000}, func(h *dbHarness) {
+
+ h.stor.Stall(testutil.ModeSync, storage.TypeTable)
+ h.put("big1", strings.Repeat("x", 10000000))
+ h.put("big2", strings.Repeat("y", 1000))
+ h.put("bar", "v2")
+ h.stor.Release(testutil.ModeSync, storage.TypeTable)
+
+ h.reopenDB()
+ h.getVal("bar", "v2")
+ h.getVal("big1", strings.Repeat("x", 10000000))
+ h.getVal("big2", strings.Repeat("y", 1000))
+ })
+}
+
+func TestDB_MinorCompactionsHappen(t *testing.T) {
+ h := newDbHarnessWopt(t, &opt.Options{DisableLargeBatchTransaction: true, WriteBuffer: 10000})
+ defer h.close()
+
+ n := 500
+
+ key := func(i int) string {
+ return fmt.Sprintf("key%06d", i)
+ }
+
+ for i := 0; i < n; i++ {
+ h.put(key(i), key(i)+strings.Repeat("v", 1000))
+ }
+
+ for i := 0; i < n; i++ {
+ h.getVal(key(i), key(i)+strings.Repeat("v", 1000))
+ }
+
+ h.reopenDB()
+ for i := 0; i < n; i++ {
+ h.getVal(key(i), key(i)+strings.Repeat("v", 1000))
+ }
+}
+
+func TestDB_RecoverWithLargeJournal(t *testing.T) {
+ h := newDbHarness(t)
+ defer h.close()
+
+ h.put("big1", strings.Repeat("1", 200000))
+ h.put("big2", strings.Repeat("2", 200000))
+ h.put("small3", strings.Repeat("3", 10))
+ h.put("small4", strings.Repeat("4", 10))
+ h.tablesPerLevel("")
+
+ // Make sure that if we re-open with a small write buffer size,
+ // we flush table files in the middle of a large journal file.
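+ // (Recovery replays the journal into a fresh memdb; with WriteBuffer
+ // shrunk to 100000 the ~400KB journal must spill more than one
+ // level-0 table, which the tLen(0) check below asserts.)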
+ h.o.WriteBuffer = 100000 + h.reopenDB() + h.getVal("big1", strings.Repeat("1", 200000)) + h.getVal("big2", strings.Repeat("2", 200000)) + h.getVal("small3", strings.Repeat("3", 10)) + h.getVal("small4", strings.Repeat("4", 10)) + v := h.db.s.version() + if v.tLen(0) <= 1 { + t.Errorf("tables-0 less than one") + } + v.release() +} + +func TestDB_CompactionsGenerateMultipleFiles(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + DisableLargeBatchTransaction: true, + WriteBuffer: 10000000, + Compression: opt.NoCompression, + }) + defer h.close() + + v := h.db.s.version() + if v.tLen(0) > 0 { + t.Errorf("level-0 tables more than 0, got %d", v.tLen(0)) + } + v.release() + + n := 80 + + // Write 8MB (80 values, each 100K) + for i := 0; i < n; i++ { + h.put(numKey(i), strings.Repeat(fmt.Sprintf("v%09d", i), 100000/10)) + } + + // Reopening moves updates to level-0 + h.reopenDB() + h.compactRangeAt(0, "", "") + + v = h.db.s.version() + if v.tLen(0) > 0 { + t.Errorf("level-0 tables more than 0, got %d", v.tLen(0)) + } + if v.tLen(1) <= 1 { + t.Errorf("level-1 tables less than 1, got %d", v.tLen(1)) + } + v.release() + + for i := 0; i < n; i++ { + h.getVal(numKey(i), strings.Repeat(fmt.Sprintf("v%09d", i), 100000/10)) + } +} + +func TestDB_RepeatedWritesToSameKey(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{DisableLargeBatchTransaction: true, WriteBuffer: 100000}) + defer h.close() + + maxTables := h.o.GetWriteL0PauseTrigger() + 7 + + value := strings.Repeat("v", 2*h.o.GetWriteBuffer()) + for i := 0; i < 5*maxTables; i++ { + h.put("key", value) + n := h.totalTables() + if n > maxTables { + t.Errorf("total tables exceed %d, got=%d, iter=%d", maxTables, n, i) + } + } +} + +func TestDB_RepeatedWritesToSameKeyAfterReopen(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + DisableLargeBatchTransaction: true, + WriteBuffer: 100000, + }) + defer h.close() + + h.reopenDB() + + maxTables := h.o.GetWriteL0PauseTrigger() + 7 + + value := strings.Repeat("v", 2*h.o.GetWriteBuffer()) + for i := 0; i < 5*maxTables; i++ { + h.put("key", value) + n := h.totalTables() + if n > maxTables { + t.Errorf("total tables exceed %d, got=%d, iter=%d", maxTables, n, i) + } + } +} + +func TestDB_SparseMerge(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{DisableLargeBatchTransaction: true, Compression: opt.NoCompression}) + defer h.close() + + h.putMulti(7, "A", "Z") + + // Suppose there is: + // small amount of data with prefix A + // large amount of data with prefix B + // small amount of data with prefix C + // and that recent updates have made small changes to all three prefixes. + // Check that we do not do a compaction that merges all of B in one shot. 
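+ // (maxNextLevelOverlappingBytes below bounds how much a single table
+ // may overlap the next level; merging all of B in one shot would blow
+ // far past the 20MiB budget.)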
+ h.put("A", "va") + value := strings.Repeat("x", 1000) + for i := 0; i < 100000; i++ { + h.put(fmt.Sprintf("B%010d", i), value) + } + h.put("C", "vc") + h.compactMem() + h.compactRangeAt(0, "", "") + h.waitCompaction() + + // Make sparse update + h.put("A", "va2") + h.put("B100", "bvalue2") + h.put("C", "vc2") + h.compactMem() + + h.waitCompaction() + h.maxNextLevelOverlappingBytes(20 * 1048576) + h.compactRangeAt(0, "", "") + h.waitCompaction() + h.maxNextLevelOverlappingBytes(20 * 1048576) + h.compactRangeAt(1, "", "") + h.waitCompaction() + h.maxNextLevelOverlappingBytes(20 * 1048576) +} + +func TestDB_SizeOf(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + DisableLargeBatchTransaction: true, + Compression: opt.NoCompression, + WriteBuffer: 10000000, + }) + defer h.close() + + h.sizeAssert("", "xyz", 0, 0) + h.reopenDB() + h.sizeAssert("", "xyz", 0, 0) + + // Write 8MB (80 values, each 100K) + n := 80 + s1 := 100000 + s2 := 105000 + + for i := 0; i < n; i++ { + h.put(numKey(i), strings.Repeat(fmt.Sprintf("v%09d", i), s1/10)) + } + + // 0 because SizeOf() does not account for memtable space + h.sizeAssert("", numKey(50), 0, 0) + + for r := 0; r < 3; r++ { + h.reopenDB() + + for cs := 0; cs < n; cs += 10 { + for i := 0; i < n; i += 10 { + h.sizeAssert("", numKey(i), int64(s1*i), int64(s2*i)) + h.sizeAssert("", numKey(i)+".suffix", int64(s1*(i+1)), int64(s2*(i+1))) + h.sizeAssert(numKey(i), numKey(i+10), int64(s1*10), int64(s2*10)) + } + + h.sizeAssert("", numKey(50), int64(s1*50), int64(s2*50)) + h.sizeAssert("", numKey(50)+".suffix", int64(s1*50), int64(s2*50)) + + h.compactRangeAt(0, numKey(cs), numKey(cs+9)) + } + + v := h.db.s.version() + if v.tLen(0) != 0 { + t.Errorf("level-0 tables was not zero, got %d", v.tLen(0)) + } + if v.tLen(1) == 0 { + t.Error("level-1 tables was zero") + } + v.release() + } +} + +func TestDB_SizeOf_MixOfSmallAndLarge(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + DisableLargeBatchTransaction: true, + Compression: opt.NoCompression, + }) + defer h.close() + + sizes := []int64{ + 10000, + 10000, + 100000, + 10000, + 100000, + 10000, + 300000, + 10000, + } + + for i, n := range sizes { + h.put(numKey(i), strings.Repeat(fmt.Sprintf("v%09d", i), int(n)/10)) + } + + for r := 0; r < 3; r++ { + h.reopenDB() + + var x int64 + for i, n := range sizes { + y := x + if i > 0 { + y += 1000 + } + h.sizeAssert("", numKey(i), x, y) + x += n + } + + h.sizeAssert(numKey(3), numKey(5), 110000, 111000) + + h.compactRangeAt(0, "", "") + } +} + +func TestDB_Snapshot(t *testing.T) { + trun(t, func(h *dbHarness) { + h.put("foo", "v1") + s1 := h.getSnapshot() + h.put("foo", "v2") + s2 := h.getSnapshot() + h.put("foo", "v3") + s3 := h.getSnapshot() + h.put("foo", "v4") + + h.getValr(s1, "foo", "v1") + h.getValr(s2, "foo", "v2") + h.getValr(s3, "foo", "v3") + h.getVal("foo", "v4") + + s3.Release() + h.getValr(s1, "foo", "v1") + h.getValr(s2, "foo", "v2") + h.getVal("foo", "v4") + + s1.Release() + h.getValr(s2, "foo", "v2") + h.getVal("foo", "v4") + + s2.Release() + h.getVal("foo", "v4") + }) +} + +func TestDB_SnapshotList(t *testing.T) { + db := &DB{snapsList: list.New()} + e0a := db.acquireSnapshot() + e0b := db.acquireSnapshot() + db.seq = 1 + e1 := db.acquireSnapshot() + db.seq = 2 + e2 := db.acquireSnapshot() + + if db.minSeq() != 0 { + t.Fatalf("invalid sequence number, got=%d", db.minSeq()) + } + db.releaseSnapshot(e0a) + if db.minSeq() != 0 { + t.Fatalf("invalid sequence number, got=%d", db.minSeq()) + } + db.releaseSnapshot(e2) + if db.minSeq() != 0 { + 
t.Fatalf("invalid sequence number, got=%d", db.minSeq()) + } + db.releaseSnapshot(e0b) + if db.minSeq() != 1 { + t.Fatalf("invalid sequence number, got=%d", db.minSeq()) + } + e2 = db.acquireSnapshot() + if db.minSeq() != 1 { + t.Fatalf("invalid sequence number, got=%d", db.minSeq()) + } + db.releaseSnapshot(e1) + if db.minSeq() != 2 { + t.Fatalf("invalid sequence number, got=%d", db.minSeq()) + } + db.releaseSnapshot(e2) + if db.minSeq() != 2 { + t.Fatalf("invalid sequence number, got=%d", db.minSeq()) + } +} + +func TestDB_HiddenValuesAreRemoved(t *testing.T) { + trun(t, func(h *dbHarness) { + s := h.db.s + + m := 2 + h.db.memdbMaxLevel = m + + h.put("foo", "v1") + h.compactMem() + v := s.version() + num := v.tLen(m) + v.release() + if num != 1 { + t.Errorf("invalid level-%d len, want=1 got=%d", m, num) + } + + // Place a table at level last-1 to prevent merging with preceding mutation + h.put("a", "begin") + h.put("z", "end") + h.compactMem() + v = s.version() + if v.tLen(m) != 1 { + t.Errorf("invalid level-%d len, want=1 got=%d", m, v.tLen(m)) + } + if v.tLen(m-1) != 1 { + t.Errorf("invalid level-%d len, want=1 got=%d", m-1, v.tLen(m-1)) + } + v.release() + + h.delete("foo") + h.put("foo", "v2") + h.allEntriesFor("foo", "[ v2, DEL, v1 ]") + h.compactMem() + h.allEntriesFor("foo", "[ v2, DEL, v1 ]") + h.compactRangeAt(m-2, "", "z") + // DEL eliminated, but v1 remains because we aren't compacting that level + // (DEL can be eliminated because v2 hides v1). + h.allEntriesFor("foo", "[ v2, v1 ]") + h.compactRangeAt(m-1, "", "") + // Merging last-1 w/ last, so we are the base level for "foo", so + // DEL is removed. (as is v1). + h.allEntriesFor("foo", "[ v2 ]") + }) +} + +func TestDB_DeletionMarkers2(t *testing.T) { + h := newDbHarness(t) + defer h.close() + s := h.db.s + + m := 2 + h.db.memdbMaxLevel = m + + h.put("foo", "v1") + h.compactMem() + v := s.version() + num := v.tLen(m) + v.release() + if num != 1 { + t.Errorf("invalid level-%d len, want=1 got=%d", m, num) + } + + // Place a table at level last-1 to prevent merging with preceding mutation + h.put("a", "begin") + h.put("z", "end") + h.compactMem() + v = s.version() + if v.tLen(m) != 1 { + t.Errorf("invalid level-%d len, want=1 got=%d", m, v.tLen(m)) + } + if v.tLen(m-1) != 1 { + t.Errorf("invalid level-%d len, want=1 got=%d", m-1, v.tLen(m-1)) + } + v.release() + + h.delete("foo") + h.allEntriesFor("foo", "[ DEL, v1 ]") + h.compactMem() // Moves to level last-2 + h.allEntriesFor("foo", "[ DEL, v1 ]") + h.compactRangeAt(m-2, "", "") + // DEL kept: "last" file overlaps + h.allEntriesFor("foo", "[ DEL, v1 ]") + h.compactRangeAt(m-1, "", "") + // Merging last-1 w/ last, so we are the base level for "foo", so + // DEL is removed. (as is v1). 
+ h.allEntriesFor("foo", "[ ]") +} + +func TestDB_CompactionTableOpenError(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + DisableLargeBatchTransaction: true, + OpenFilesCacheCapacity: -1, + }) + defer h.close() + + h.db.memdbMaxLevel = 2 + + im := 10 + jm := 10 + for r := 0; r < 2; r++ { + for i := 0; i < im; i++ { + for j := 0; j < jm; j++ { + h.put(fmt.Sprintf("k%d,%d", i, j), fmt.Sprintf("v%d,%d", i, j)) + } + h.compactMem() + } + } + + if n := h.totalTables(); n != im*2 { + t.Errorf("total tables is %d, want %d", n, im*2) + } + + h.stor.EmulateError(testutil.ModeOpen, storage.TypeTable, errors.New("open error during table compaction")) + go h.db.CompactRange(util.Range{}) + if err := h.db.compTriggerWait(h.db.tcompCmdC); err != nil { + t.Log("compaction error: ", err) + } + h.closeDB0() + h.openDB() + h.stor.EmulateError(testutil.ModeOpen, storage.TypeTable, nil) + + for i := 0; i < im; i++ { + for j := 0; j < jm; j++ { + h.getVal(fmt.Sprintf("k%d,%d", i, j), fmt.Sprintf("v%d,%d", i, j)) + } + } +} + +func TestDB_OverlapInLevel0(t *testing.T) { + trun(t, func(h *dbHarness) { + h.db.memdbMaxLevel = 2 + + // Fill levels 1 and 2 to disable the pushing of new memtables to levels > 0. + h.put("100", "v100") + h.put("999", "v999") + h.compactMem() + h.delete("100") + h.delete("999") + h.compactMem() + h.tablesPerLevel("0,1,1") + + // Make files spanning the following ranges in level-0: + // files[0] 200 .. 900 + // files[1] 300 .. 500 + // Note that files are sorted by min key. + h.put("300", "v300") + h.put("500", "v500") + h.compactMem() + h.put("200", "v200") + h.put("600", "v600") + h.put("900", "v900") + h.compactMem() + h.tablesPerLevel("2,1,1") + + // Compact away the placeholder files we created initially + h.compactRangeAt(1, "", "") + h.compactRangeAt(2, "", "") + h.tablesPerLevel("2") + + // Do a memtable compaction. Before bug-fix, the compaction would + // not detect the overlap with level-0 files and would incorrectly place + // the deletion in a deeper level. 
+ h.delete("600") + h.compactMem() + h.tablesPerLevel("3") + h.get("600", false) + }) +} + +func TestDB_L0_CompactionBug_Issue44_a(t *testing.T) { + h := newDbHarness(t) + defer h.close() + + h.reopenDB() + h.put("b", "v") + h.reopenDB() + h.delete("b") + h.delete("a") + h.reopenDB() + h.delete("a") + h.reopenDB() + h.put("a", "v") + h.reopenDB() + h.reopenDB() + h.getKeyVal("(a->v)") + h.waitCompaction() + h.getKeyVal("(a->v)") +} + +func TestDB_L0_CompactionBug_Issue44_b(t *testing.T) { + h := newDbHarness(t) + defer h.close() + + h.reopenDB() + h.put("", "") + h.reopenDB() + h.delete("e") + h.put("", "") + h.reopenDB() + h.put("c", "cv") + h.reopenDB() + h.put("", "") + h.reopenDB() + h.put("", "") + h.waitCompaction() + h.reopenDB() + h.put("d", "dv") + h.reopenDB() + h.put("", "") + h.reopenDB() + h.delete("d") + h.delete("b") + h.reopenDB() + h.getKeyVal("(->)(c->cv)") + h.waitCompaction() + h.getKeyVal("(->)(c->cv)") +} + +func TestDB_SingleEntryMemCompaction(t *testing.T) { + trun(t, func(h *dbHarness) { + for i := 0; i < 10; i++ { + h.put("big", strings.Repeat("v", opt.DefaultWriteBuffer)) + h.compactMem() + h.put("key", strings.Repeat("v", opt.DefaultBlockSize)) + h.compactMem() + h.put("k", "v") + h.compactMem() + h.put("", "") + h.compactMem() + h.put("verybig", strings.Repeat("v", opt.DefaultWriteBuffer*2)) + h.compactMem() + } + }) +} + +func TestDB_ManifestWriteError(t *testing.T) { + for i := 0; i < 2; i++ { + func() { + h := newDbHarness(t) + defer h.close() + + h.put("foo", "bar") + h.getVal("foo", "bar") + + // Mem compaction (will succeed) + h.compactMem() + h.getVal("foo", "bar") + v := h.db.s.version() + if n := v.tLen(0); n != 1 { + t.Errorf("invalid total tables, want=1 got=%d", n) + } + v.release() + + if i == 0 { + h.stor.EmulateError(testutil.ModeWrite, storage.TypeManifest, errors.New("manifest write error")) + } else { + h.stor.EmulateError(testutil.ModeSync, storage.TypeManifest, errors.New("manifest sync error")) + } + + // Merging compaction (will fail) + h.compactRangeAtErr(0, "", "", true) + + h.db.Close() + h.stor.EmulateError(testutil.ModeWrite, storage.TypeManifest, nil) + h.stor.EmulateError(testutil.ModeSync, storage.TypeManifest, nil) + + // Should not lose data + h.openDB() + h.getVal("foo", "bar") + }() + } +} + +func assertErr(t *testing.T, err error, wanterr bool) { + if err != nil { + if wanterr { + t.Log("AssertErr: got error (expected): ", err) + } else { + t.Error("AssertErr: got error: ", err) + } + } else if wanterr { + t.Error("AssertErr: expect error") + } +} + +func TestDB_ClosedIsClosed(t *testing.T) { + h := newDbHarness(t) + db := h.db + + var iter, iter2 iterator.Iterator + var snap *Snapshot + func() { + defer h.close() + + h.put("k", "v") + h.getVal("k", "v") + + iter = db.NewIterator(nil, h.ro) + iter.Seek([]byte("k")) + testKeyVal(t, iter, "k->v") + + var err error + snap, err = db.GetSnapshot() + if err != nil { + t.Fatal("GetSnapshot: got error: ", err) + } + + h.getValr(snap, "k", "v") + + iter2 = snap.NewIterator(nil, h.ro) + iter2.Seek([]byte("k")) + testKeyVal(t, iter2, "k->v") + + h.put("foo", "v2") + h.delete("foo") + + // closing DB + iter.Release() + iter2.Release() + }() + + assertErr(t, db.Put([]byte("x"), []byte("y"), h.wo), true) + _, err := db.Get([]byte("k"), h.ro) + assertErr(t, err, true) + + if iter.Valid() { + t.Errorf("iter.Valid should false") + } + assertErr(t, iter.Error(), false) + testKeyVal(t, iter, "->") + if iter.Seek([]byte("k")) { + t.Errorf("iter.Seek should false") + } + assertErr(t, iter.Error(), 
true) + + assertErr(t, iter2.Error(), false) + + _, err = snap.Get([]byte("k"), h.ro) + assertErr(t, err, true) + + _, err = db.GetSnapshot() + assertErr(t, err, true) + + iter3 := db.NewIterator(nil, h.ro) + assertErr(t, iter3.Error(), true) + + iter3 = snap.NewIterator(nil, h.ro) + assertErr(t, iter3.Error(), true) + + assertErr(t, db.Delete([]byte("k"), h.wo), true) + + _, err = db.GetProperty("leveldb.stats") + assertErr(t, err, true) + + _, err = db.SizeOf([]util.Range{{Start: []byte("a"), Limit: []byte("z")}}) + assertErr(t, err, true) + + assertErr(t, db.CompactRange(util.Range{}), true) + + assertErr(t, db.Close(), true) +} + +type numberComparer struct{} + +func (numberComparer) num(x []byte) (n int) { + fmt.Sscan(string(x[1:len(x)-1]), &n) + return +} + +func (numberComparer) Name() string { + return "test.NumberComparer" +} + +func (p numberComparer) Compare(a, b []byte) int { + return p.num(a) - p.num(b) +} + +func (numberComparer) Separator(dst, a, b []byte) []byte { return nil } +func (numberComparer) Successor(dst, b []byte) []byte { return nil } + +func TestDB_CustomComparer(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + DisableLargeBatchTransaction: true, + Comparer: numberComparer{}, + WriteBuffer: 1000, + }) + defer h.close() + + h.put("[10]", "ten") + h.put("[0x14]", "twenty") + for i := 0; i < 2; i++ { + h.getVal("[10]", "ten") + h.getVal("[0xa]", "ten") + h.getVal("[20]", "twenty") + h.getVal("[0x14]", "twenty") + h.get("[15]", false) + h.get("[0xf]", false) + h.compactMem() + h.compactRange("[0]", "[9999]") + } + + for n := 0; n < 2; n++ { + for i := 0; i < 100; i++ { + v := fmt.Sprintf("[%d]", i*10) + h.put(v, v) + } + h.compactMem() + h.compactRange("[0]", "[1000000]") + } +} + +func TestDB_ManualCompaction(t *testing.T) { + h := newDbHarness(t) + defer h.close() + + h.db.memdbMaxLevel = 2 + + h.putMulti(3, "p", "q") + h.tablesPerLevel("1,1,1") + + // Compaction range falls before files + h.compactRange("", "c") + h.tablesPerLevel("1,1,1") + + // Compaction range falls after files + h.compactRange("r", "z") + h.tablesPerLevel("1,1,1") + + // Compaction range overlaps files + h.compactRange("p1", "p9") + h.tablesPerLevel("0,0,1") + + // Populate a different range + h.putMulti(3, "c", "e") + h.tablesPerLevel("1,1,2") + + // Compact just the new range + h.compactRange("b", "f") + h.tablesPerLevel("0,0,2") + + // Compact all + h.putMulti(1, "a", "z") + h.tablesPerLevel("0,1,2") + h.compactRange("", "") + h.tablesPerLevel("0,0,1") +} + +func TestDB_BloomFilter(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + DisableLargeBatchTransaction: true, + DisableBlockCache: true, + Filter: filter.NewBloomFilter(10), + }) + defer h.close() + + key := func(i int) string { + return fmt.Sprintf("key%06d", i) + } + + const n = 10000 + + // Populate multiple layers + for i := 0; i < n; i++ { + h.put(key(i), key(i)) + } + h.compactMem() + h.compactRange("a", "z") + for i := 0; i < n; i += 100 { + h.put(key(i), key(i)) + } + h.compactMem() + + // Prevent auto compactions triggered by seeks + h.stor.Stall(testutil.ModeSync, storage.TypeTable) + + // Lookup present keys. Should rarely read from small sstable. 
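+ // (A 10-bits-per-key bloom filter has a false-positive rate of
+ // roughly 1%, which is where the n..n+2% window on the read counter
+ // below comes from.)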
+ h.stor.ResetCounter(testutil.ModeRead, storage.TypeTable) + for i := 0; i < n; i++ { + h.getVal(key(i), key(i)) + } + cnt, _ := h.stor.Counter(testutil.ModeRead, storage.TypeTable) + t.Logf("lookup of %d present keys yield %d sstable I/O reads", n, cnt) + if min, max := n, n+2*n/100; cnt < min || cnt > max { + t.Errorf("num of sstable I/O reads of present keys not in range of %d - %d, got %d", min, max, cnt) + } + + // Lookup missing keys. Should rarely read from either sstable. + h.stor.ResetCounter(testutil.ModeRead, storage.TypeTable) + for i := 0; i < n; i++ { + h.get(key(i)+".missing", false) + } + cnt, _ = h.stor.Counter(testutil.ModeRead, storage.TypeTable) + t.Logf("lookup of %d missing keys yield %d sstable I/O reads", n, cnt) + if max := 3 * n / 100; cnt > max { + t.Errorf("num of sstable I/O reads of missing keys was more than %d, got %d", max, cnt) + } + + h.stor.Release(testutil.ModeSync, storage.TypeTable) +} + +func TestDB_Concurrent(t *testing.T) { + const n, secs, maxkey = 4, 6, 1000 + h := newDbHarness(t) + defer h.close() + + runtime.GOMAXPROCS(runtime.NumCPU()) + + var ( + closeWg sync.WaitGroup + stop uint32 + cnt [n]uint32 + ) + + for i := 0; i < n; i++ { + closeWg.Add(1) + go func(i int) { + var put, get, found uint + defer func() { + t.Logf("goroutine %d stopped after %d ops, put=%d get=%d found=%d missing=%d", + i, cnt[i], put, get, found, get-found) + closeWg.Done() + }() + + rnd := rand.New(rand.NewSource(int64(1000 + i))) + for atomic.LoadUint32(&stop) == 0 { + x := cnt[i] + + k := rnd.Intn(maxkey) + kstr := fmt.Sprintf("%016d", k) + + if (rnd.Int() % 2) > 0 { + put++ + h.put(kstr, fmt.Sprintf("%d.%d.%-1000d", k, i, x)) + } else { + get++ + v, err := h.db.Get([]byte(kstr), h.ro) + if err == nil { + found++ + rk, ri, rx := 0, -1, uint32(0) + fmt.Sscanf(string(v), "%d.%d.%d", &rk, &ri, &rx) + if rk != k { + t.Errorf("invalid key want=%d got=%d", k, rk) + } + if ri < 0 || ri >= n { + t.Error("invalid goroutine number: ", ri) + } else { + tx := atomic.LoadUint32(&(cnt[ri])) + if rx > tx { + t.Errorf("invalid seq number, %d > %d ", rx, tx) + } + } + } else if err != ErrNotFound { + t.Error("Get: got error: ", err) + return + } + } + atomic.AddUint32(&cnt[i], 1) + } + }(i) + } + + time.Sleep(secs * time.Second) + atomic.StoreUint32(&stop, 1) + closeWg.Wait() +} + +func TestDB_ConcurrentIterator(t *testing.T) { + const n, n2 = 4, 1000 + h := newDbHarnessWopt(t, &opt.Options{DisableLargeBatchTransaction: true, WriteBuffer: 30}) + defer h.close() + + runtime.GOMAXPROCS(runtime.NumCPU()) + + var ( + closeWg sync.WaitGroup + stop uint32 + ) + + for i := 0; i < n; i++ { + closeWg.Add(1) + go func(i int) { + for k := 0; atomic.LoadUint32(&stop) == 0; k++ { + h.put(fmt.Sprintf("k%d", k), fmt.Sprintf("%d.%d.", k, i)+strings.Repeat("x", 10)) + } + closeWg.Done() + }(i) + } + + for i := 0; i < n; i++ { + closeWg.Add(1) + go func(i int) { + for k := 1000000; k < 0 || atomic.LoadUint32(&stop) == 0; k-- { + h.put(fmt.Sprintf("k%d", k), fmt.Sprintf("%d.%d.", k, i)+strings.Repeat("x", 10)) + } + closeWg.Done() + }(i) + } + + cmp := comparer.DefaultComparer + for i := 0; i < n2; i++ { + closeWg.Add(1) + go func(i int) { + it := h.db.NewIterator(nil, nil) + var pk []byte + for it.Next() { + kk := it.Key() + if cmp.Compare(kk, pk) <= 0 { + t.Errorf("iter %d: %q is successor of %q", i, pk, kk) + } + pk = append(pk[:0], kk...) 
+ var k, vk, vi int
+ if n, err := fmt.Sscanf(string(it.Key()), "k%d", &k); err != nil {
+ t.Errorf("iter %d: Scanf error on key %q: %v", i, it.Key(), err)
+ } else if n < 1 {
+ t.Errorf("iter %d: Cannot parse key %q", i, it.Key())
+ }
+ if n, err := fmt.Sscanf(string(it.Value()), "%d.%d", &vk, &vi); err != nil {
+ t.Errorf("iter %d: Scanf error on value %q: %v", i, it.Value(), err)
+ } else if n < 2 {
+ t.Errorf("iter %d: Cannot parse value %q", i, it.Value())
+ }
+
+ if vk != k {
+ t.Errorf("iter %d: invalid value i=%d, want=%d got=%d", i, vi, k, vk)
+ }
+ }
+ if err := it.Error(); err != nil {
+ t.Errorf("iter %d: Got error: %v", i, err)
+ }
+ it.Release()
+ closeWg.Done()
+ }(i)
+ }
+
+ atomic.StoreUint32(&stop, 1)
+ closeWg.Wait()
+}
+
+func TestDB_ConcurrentWrite(t *testing.T) {
+ const n, bk, niter = 10, 3, 10000
+ h := newDbHarness(t)
+ defer h.close()
+
+ runtime.GOMAXPROCS(runtime.NumCPU())
+
+ var wg sync.WaitGroup
+ for i := 0; i < n; i++ {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+ for k := 0; k < niter; k++ {
+ kstr := fmt.Sprintf("put-%d.%d", i, k)
+ vstr := fmt.Sprintf("v%d", k)
+ h.put(kstr, vstr)
+ // Key should be immediately available after put returns.
+ h.getVal(kstr, vstr)
+ }
+ }(i)
+ }
+ for i := 0; i < n; i++ {
+ wg.Add(1)
+ batch := &Batch{}
+ go func(i int) {
+ defer wg.Done()
+ for k := 0; k < niter; k++ {
+ batch.Reset()
+ for j := 0; j < bk; j++ {
+ batch.Put([]byte(fmt.Sprintf("batch-%d.%d.%d", i, k, j)), []byte(fmt.Sprintf("v%d", k)))
+ }
+ h.write(batch)
+ // Keys should be immediately available after the batch write returns.
+ for j := 0; j < bk; j++ {
+ h.getVal(fmt.Sprintf("batch-%d.%d.%d", i, k, j), fmt.Sprintf("v%d", k))
+ }
+ }
+ }(i)
+ }
+ wg.Wait()
+}
+
+func TestDB_CreateReopenDbOnFile(t *testing.T) {
+ dbpath := filepath.Join(os.TempDir(), fmt.Sprintf("goleveldbtestCreateReopenDbOnFile-%d", os.Getuid()))
+ if err := os.RemoveAll(dbpath); err != nil {
+ t.Fatal("cannot remove old db: ", err)
+ }
+ defer os.RemoveAll(dbpath)
+
+ for i := 0; i < 3; i++ {
+ stor, err := storage.OpenFile(dbpath, false)
+ if err != nil {
+ t.Fatalf("(%d) cannot open storage: %s", i, err)
+ }
+ db, err := Open(stor, nil)
+ if err != nil {
+ t.Fatalf("(%d) cannot open db: %s", i, err)
+ }
+ if err := db.Put([]byte("foo"), []byte("bar"), nil); err != nil {
+ t.Fatalf("(%d) cannot write to db: %s", i, err)
+ }
+ if err := db.Close(); err != nil {
+ t.Fatalf("(%d) cannot close db: %s", i, err)
+ }
+ if err := stor.Close(); err != nil {
+ t.Fatalf("(%d) cannot close storage: %s", i, err)
+ }
+ }
+}
+
+func TestDB_CreateReopenDbOnFile2(t *testing.T) {
+ dbpath := filepath.Join(os.TempDir(), fmt.Sprintf("goleveldbtestCreateReopenDbOnFile2-%d", os.Getuid()))
+ if err := os.RemoveAll(dbpath); err != nil {
+ t.Fatal("cannot remove old db: ", err)
+ }
+ defer os.RemoveAll(dbpath)
+
+ for i := 0; i < 3; i++ {
+ db, err := OpenFile(dbpath, nil)
+ if err != nil {
+ t.Fatalf("(%d) cannot open db: %s", i, err)
+ }
+ if err := db.Put([]byte("foo"), []byte("bar"), nil); err != nil {
+ t.Fatalf("(%d) cannot write to db: %s", i, err)
+ }
+ if err := db.Close(); err != nil {
+ t.Fatalf("(%d) cannot close db: %s", i, err)
+ }
+ }
+}
+
+func TestDB_DeletionMarkersOnMemdb(t *testing.T) {
+ h := newDbHarness(t)
+ defer h.close()
+
+ h.put("foo", "v1")
+ h.compactMem()
+ h.delete("foo")
+ h.get("foo", false)
+ h.getKeyVal("")
+}
+
+func TestDB_LeveldbIssue178(t *testing.T) {
+ nKeys := (opt.DefaultCompactionTableSize / 30) * 5
+ key1 := func(i int) string {
+ return fmt.Sprintf("my_key_%d", i)
+ }
+ key2
:= func(i int) string { + return fmt.Sprintf("my_key_%d_xxx", i) + } + + // Disable compression since it affects the creation of layers and the + // code below is trying to test against a very specific scenario. + h := newDbHarnessWopt(t, &opt.Options{ + DisableLargeBatchTransaction: true, + Compression: opt.NoCompression, + }) + defer h.close() + + // Create first key range. + batch := new(Batch) + for i := 0; i < nKeys; i++ { + batch.Put([]byte(key1(i)), []byte("value for range 1 key")) + } + h.write(batch) + + // Create second key range. + batch.Reset() + for i := 0; i < nKeys; i++ { + batch.Put([]byte(key2(i)), []byte("value for range 2 key")) + } + h.write(batch) + + // Delete second key range. + batch.Reset() + for i := 0; i < nKeys; i++ { + batch.Delete([]byte(key2(i))) + } + h.write(batch) + h.waitMemCompaction() + + // Run manual compaction. + h.compactRange(key1(0), key1(nKeys-1)) + + // Checking the keys. + h.assertNumKeys(nKeys) +} + +func TestDB_LeveldbIssue200(t *testing.T) { + h := newDbHarness(t) + defer h.close() + + h.put("1", "b") + h.put("2", "c") + h.put("3", "d") + h.put("4", "e") + h.put("5", "f") + + iter := h.db.NewIterator(nil, h.ro) + + // Add an element that should not be reflected in the iterator. + h.put("25", "cd") + + iter.Seek([]byte("5")) + assertBytes(t, []byte("5"), iter.Key()) + iter.Prev() + assertBytes(t, []byte("4"), iter.Key()) + iter.Prev() + assertBytes(t, []byte("3"), iter.Key()) + iter.Next() + assertBytes(t, []byte("4"), iter.Key()) + iter.Next() + assertBytes(t, []byte("5"), iter.Key()) +} + +func TestDB_GoleveldbIssue74(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + DisableLargeBatchTransaction: true, + WriteBuffer: 1 * opt.MiB, + }) + defer h.close() + + const n, dur = 10000, 5 * time.Second + + runtime.GOMAXPROCS(runtime.NumCPU()) + + until := time.Now().Add(dur) + wg := new(sync.WaitGroup) + wg.Add(2) + var done uint32 + go func() { + var i int + defer func() { + t.Logf("WRITER DONE #%d", i) + atomic.StoreUint32(&done, 1) + wg.Done() + }() + + b := new(Batch) + for ; time.Now().Before(until) && atomic.LoadUint32(&done) == 0; i++ { + iv := fmt.Sprintf("VAL%010d", i) + for k := 0; k < n; k++ { + key := fmt.Sprintf("KEY%06d", k) + b.Put([]byte(key), []byte(key+iv)) + b.Put([]byte(fmt.Sprintf("PTR%06d", k)), []byte(key)) + } + h.write(b) + + b.Reset() + snap := h.getSnapshot() + iter := snap.NewIterator(util.BytesPrefix([]byte("PTR")), nil) + var k int + for ; iter.Next(); k++ { + ptrKey := iter.Key() + key := iter.Value() + + if _, err := snap.Get(ptrKey, nil); err != nil { + t.Fatalf("WRITER #%d snapshot.Get %q: %v", i, ptrKey, err) + } + if value, err := snap.Get(key, nil); err != nil { + t.Fatalf("WRITER #%d snapshot.Get %q: %v", i, key, err) + } else if string(value) != string(key)+iv { + t.Fatalf("WRITER #%d snapshot.Get %q got invalid value, want %q got %q", i, key, string(key)+iv, value) + } + + b.Delete(key) + b.Delete(ptrKey) + } + h.write(b) + iter.Release() + snap.Release() + if k != n { + t.Fatalf("#%d %d != %d", i, k, n) + } + } + }() + go func() { + var i int + defer func() { + t.Logf("READER DONE #%d", i) + atomic.StoreUint32(&done, 1) + wg.Done() + }() + for ; time.Now().Before(until) && atomic.LoadUint32(&done) == 0; i++ { + snap := h.getSnapshot() + iter := snap.NewIterator(util.BytesPrefix([]byte("PTR")), nil) + var prevValue string + var k int + for ; iter.Next(); k++ { + ptrKey := iter.Key() + key := iter.Value() + + if _, err := snap.Get(ptrKey, nil); err != nil { + t.Fatalf("READER #%d snapshot.Get %q: %v", 
i, ptrKey, err)
+ }
+
+ if value, err := snap.Get(key, nil); err != nil {
+ t.Fatalf("READER #%d snapshot.Get %q: %v", i, key, err)
+ } else if prevValue != "" && string(value) != string(key)+prevValue {
+ t.Fatalf("READER #%d snapshot.Get %q got invalid value, want %q got %q", i, key, string(key)+prevValue, value)
+ } else {
+ prevValue = string(value[len(key):])
+ }
+ }
+ iter.Release()
+ snap.Release()
+ if k > 0 && k != n {
+ t.Fatalf("#%d %d != %d", i, k, n)
+ }
+ }
+ }()
+ wg.Wait()
+}
+
+func TestDB_GetProperties(t *testing.T) {
+ h := newDbHarness(t)
+ defer h.close()
+
+ _, err := h.db.GetProperty("leveldb.num-files-at-level")
+ if err == nil {
+ t.Error("GetProperty() failed to detect missing level")
+ }
+
+ _, err = h.db.GetProperty("leveldb.num-files-at-level0")
+ if err != nil {
+ t.Error("got unexpected error", err)
+ }
+
+ _, err = h.db.GetProperty("leveldb.num-files-at-level0x")
+ if err == nil {
+ t.Error("GetProperty() failed to detect invalid level")
+ }
+}
+
+func TestDB_GoleveldbIssue72and83(t *testing.T) {
+ h := newDbHarnessWopt(t, &opt.Options{
+ DisableLargeBatchTransaction: true,
+ WriteBuffer: 1 * opt.MiB,
+ OpenFilesCacheCapacity: 3,
+ })
+ defer h.close()
+
+ const n, wn, dur = 10000, 100, 30 * time.Second
+
+ runtime.GOMAXPROCS(runtime.NumCPU())
+
+ randomData := func(prefix byte, i int) []byte {
+ data := make([]byte, 1+4+32+64+32)
+ _, err := crand.Reader.Read(data[1 : len(data)-8])
+ if err != nil {
+ panic(err)
+ }
+ data[0] = prefix
+ binary.LittleEndian.PutUint32(data[len(data)-8:], uint32(i))
+ binary.LittleEndian.PutUint32(data[len(data)-4:], util.NewCRC(data[:len(data)-4]).Value())
+ return data
+ }
+
+ keys := make([][]byte, n)
+ for i := range keys {
+ keys[i] = randomData(1, 0)
+ }
+
+ until := time.Now().Add(dur)
+ wg := new(sync.WaitGroup)
+ wg.Add(3)
+ var done uint32
+ go func() {
+ i := 0
+ defer func() {
+ t.Logf("WRITER DONE #%d", i)
+ wg.Done()
+ }()
+
+ b := new(Batch)
+ for ; i < wn && atomic.LoadUint32(&done) == 0; i++ {
+ b.Reset()
+ for _, k1 := range keys {
+ k2 := randomData(2, i)
+ b.Put(k2, randomData(42, i))
+ b.Put(k1, k2)
+ }
+ if err := h.db.Write(b, h.wo); err != nil {
+ atomic.StoreUint32(&done, 1)
+ t.Fatalf("WRITER #%d db.Write: %v", i, err)
+ }
+ }
+ }()
+ go func() {
+ var i int
+ defer func() {
+ t.Logf("READER0 DONE #%d", i)
+ atomic.StoreUint32(&done, 1)
+ wg.Done()
+ }()
+ for ; time.Now().Before(until) && atomic.LoadUint32(&done) == 0; i++ {
+ snap := h.getSnapshot()
+ seq := snap.elem.seq
+ if seq == 0 {
+ snap.Release()
+ continue
+ }
+ iter := snap.NewIterator(util.BytesPrefix([]byte{1}), nil)
+ writei := int(seq/(n*2) - 1)
+ var k int
+ for ; iter.Next(); k++ {
+ k1 := iter.Key()
+ k2 := iter.Value()
+ k1checksum0 := binary.LittleEndian.Uint32(k1[len(k1)-4:])
+ k1checksum1 := util.NewCRC(k1[:len(k1)-4]).Value()
+ if k1checksum0 != k1checksum1 {
+ t.Fatalf("READER0 #%d.%d W#%d invalid K1 checksum: %#x != %#x", i, k, writei, k1checksum0, k1checksum1)
+ }
+ k2checksum0 := binary.LittleEndian.Uint32(k2[len(k2)-4:])
+ k2checksum1 := util.NewCRC(k2[:len(k2)-4]).Value()
+ if k2checksum0 != k2checksum1 {
+ t.Fatalf("READER0 #%d.%d W#%d invalid K2 checksum: %#x != %#x", i, k, writei, k2checksum0, k2checksum1)
+ }
+ kwritei := int(binary.LittleEndian.Uint32(k2[len(k2)-8:]))
+ if writei != kwritei {
+ t.Fatalf("READER0 #%d.%d W#%d invalid write iteration num: %d", i, k, writei, kwritei)
+ }
+ if _, err := snap.Get(k2, nil); err != nil {
+ t.Fatalf("READER0 #%d.%d W#%d snap.Get: %v\nk1: %x\n -> k2: %x", i, k, writei, err,
k1, k2) + } + } + if err := iter.Error(); err != nil { + t.Fatalf("READER0 #%d.%d W#%d snap.Iterator: %v", i, k, writei, err) + } + iter.Release() + snap.Release() + if k > 0 && k != n { + t.Fatalf("READER0 #%d W#%d short read, got=%d want=%d", i, writei, k, n) + } + } + }() + go func() { + var i int + defer func() { + t.Logf("READER1 DONE #%d", i) + atomic.StoreUint32(&done, 1) + wg.Done() + }() + for ; time.Now().Before(until) && atomic.LoadUint32(&done) == 0; i++ { + iter := h.db.NewIterator(nil, nil) + seq := iter.(*dbIter).seq + if seq == 0 { + iter.Release() + continue + } + writei := int(seq/(n*2) - 1) + var k int + for ok := iter.Last(); ok; ok = iter.Prev() { + k++ + } + if err := iter.Error(); err != nil { + t.Fatalf("READER1 #%d.%d W#%d db.Iterator: %v", i, k, writei, err) + } + iter.Release() + if m := (writei+1)*n + n; k != m { + t.Fatalf("READER1 #%d W#%d short read, got=%d want=%d", i, writei, k, m) + } + } + }() + + wg.Wait() +} + +func TestDB_TransientError(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + DisableLargeBatchTransaction: true, + WriteBuffer: 128 * opt.KiB, + OpenFilesCacheCapacity: 3, + DisableCompactionBackoff: true, + }) + defer h.close() + + const ( + nSnap = 20 + nKey = 10000 + ) + + var ( + snaps [nSnap]*Snapshot + b = &Batch{} + ) + for i := range snaps { + vtail := fmt.Sprintf("VAL%030d", i) + b.Reset() + for k := 0; k < nKey; k++ { + key := fmt.Sprintf("KEY%8d", k) + b.Put([]byte(key), []byte(key+vtail)) + } + h.stor.EmulateError(testutil.ModeOpen|testutil.ModeRead, storage.TypeTable, errors.New("table transient read error")) + if err := h.db.Write(b, nil); err != nil { + t.Logf("WRITE #%d error: %v", i, err) + h.stor.EmulateError(testutil.ModeOpen|testutil.ModeRead, storage.TypeTable, nil) + for { + if err := h.db.Write(b, nil); err == nil { + break + } else if errors.IsCorrupted(err) { + t.Fatalf("WRITE #%d corrupted: %v", i, err) + } + } + } + + snaps[i] = h.db.newSnapshot() + b.Reset() + for k := 0; k < nKey; k++ { + key := fmt.Sprintf("KEY%8d", k) + b.Delete([]byte(key)) + } + h.stor.EmulateError(testutil.ModeOpen|testutil.ModeRead, storage.TypeTable, errors.New("table transient read error")) + if err := h.db.Write(b, nil); err != nil { + t.Logf("WRITE #%d error: %v", i, err) + h.stor.EmulateError(testutil.ModeOpen|testutil.ModeRead, storage.TypeTable, nil) + for { + if err := h.db.Write(b, nil); err == nil { + break + } else if errors.IsCorrupted(err) { + t.Fatalf("WRITE #%d corrupted: %v", i, err) + } + } + } + } + h.stor.EmulateError(testutil.ModeOpen|testutil.ModeRead, storage.TypeTable, nil) + + runtime.GOMAXPROCS(runtime.NumCPU()) + + rnd := rand.New(rand.NewSource(0xecafdaed)) + wg := &sync.WaitGroup{} + for i, snap := range snaps { + wg.Add(2) + + go func(i int, snap *Snapshot, sk []int) { + defer wg.Done() + + vtail := fmt.Sprintf("VAL%030d", i) + for _, k := range sk { + key := fmt.Sprintf("KEY%8d", k) + xvalue, err := snap.Get([]byte(key), nil) + if err != nil { + t.Fatalf("READER_GET #%d SEQ=%d K%d error: %v", i, snap.elem.seq, k, err) + } + value := key + vtail + if !bytes.Equal([]byte(value), xvalue) { + t.Fatalf("READER_GET #%d SEQ=%d K%d invalid value: want %q, got %q", i, snap.elem.seq, k, value, xvalue) + } + } + }(i, snap, rnd.Perm(nKey)) + + go func(i int, snap *Snapshot) { + defer wg.Done() + + vtail := fmt.Sprintf("VAL%030d", i) + iter := snap.NewIterator(nil, nil) + defer iter.Release() + for k := 0; k < nKey; k++ { + if !iter.Next() { + if err := iter.Error(); err != nil { + t.Fatalf("READER_ITER #%d K%d error: 
%v", i, k, err) + } else { + t.Fatalf("READER_ITER #%d K%d eoi", i, k) + } + } + key := fmt.Sprintf("KEY%8d", k) + xkey := iter.Key() + if !bytes.Equal([]byte(key), xkey) { + t.Fatalf("READER_ITER #%d K%d invalid key: want %q, got %q", i, k, key, xkey) + } + value := key + vtail + xvalue := iter.Value() + if !bytes.Equal([]byte(value), xvalue) { + t.Fatalf("READER_ITER #%d K%d invalid value: want %q, got %q", i, k, value, xvalue) + } + } + }(i, snap) + } + + wg.Wait() +} + +func TestDB_UkeyShouldntHopAcrossTable(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + DisableLargeBatchTransaction: true, + WriteBuffer: 112 * opt.KiB, + CompactionTableSize: 90 * opt.KiB, + CompactionExpandLimitFactor: 1, + }) + defer h.close() + + const ( + nSnap = 190 + nKey = 140 + ) + + var ( + snaps [nSnap]*Snapshot + b = &Batch{} + ) + for i := range snaps { + vtail := fmt.Sprintf("VAL%030d", i) + b.Reset() + for k := 0; k < nKey; k++ { + key := fmt.Sprintf("KEY%08d", k) + b.Put([]byte(key), []byte(key+vtail)) + } + if err := h.db.Write(b, nil); err != nil { + t.Fatalf("WRITE #%d error: %v", i, err) + } + + snaps[i] = h.db.newSnapshot() + b.Reset() + for k := 0; k < nKey; k++ { + key := fmt.Sprintf("KEY%08d", k) + b.Delete([]byte(key)) + } + if err := h.db.Write(b, nil); err != nil { + t.Fatalf("WRITE #%d error: %v", i, err) + } + } + + h.compactMem() + + h.waitCompaction() + for level, tables := range h.db.s.stVersion.levels { + for _, table := range tables { + t.Logf("L%d@%d %q:%q", level, table.fd.Num, table.imin, table.imax) + } + } + + h.compactRangeAt(0, "", "") + h.waitCompaction() + for level, tables := range h.db.s.stVersion.levels { + for _, table := range tables { + t.Logf("L%d@%d %q:%q", level, table.fd.Num, table.imin, table.imax) + } + } + h.compactRangeAt(1, "", "") + h.waitCompaction() + for level, tables := range h.db.s.stVersion.levels { + for _, table := range tables { + t.Logf("L%d@%d %q:%q", level, table.fd.Num, table.imin, table.imax) + } + } + runtime.GOMAXPROCS(runtime.NumCPU()) + + wg := &sync.WaitGroup{} + for i, snap := range snaps { + wg.Add(1) + + go func(i int, snap *Snapshot) { + defer wg.Done() + + vtail := fmt.Sprintf("VAL%030d", i) + for k := 0; k < nKey; k++ { + key := fmt.Sprintf("KEY%08d", k) + xvalue, err := snap.Get([]byte(key), nil) + if err != nil { + t.Fatalf("READER_GET #%d SEQ=%d K%d error: %v", i, snap.elem.seq, k, err) + } + value := key + vtail + if !bytes.Equal([]byte(value), xvalue) { + t.Fatalf("READER_GET #%d SEQ=%d K%d invalid value: want %q, got %q", i, snap.elem.seq, k, value, xvalue) + } + } + }(i, snap) + } + + wg.Wait() +} + +func TestDB_TableCompactionBuilder(t *testing.T) { + gomega.RegisterTestingT(t) + stor := testutil.NewStorage() + stor.OnLog(testingLogger(t)) + stor.OnClose(testingPreserveOnFailed(t)) + defer stor.Close() + + const nSeq = 99 + + o := &opt.Options{ + DisableLargeBatchTransaction: true, + WriteBuffer: 112 * opt.KiB, + CompactionTableSize: 43 * opt.KiB, + CompactionExpandLimitFactor: 1, + CompactionGPOverlapsFactor: 1, + DisableBlockCache: true, + } + s, err := newSession(stor, o) + if err != nil { + t.Fatal(err) + } + if err := s.create(); err != nil { + t.Fatal(err) + } + defer s.close() + var ( + seq uint64 + targetSize = 5 * o.CompactionTableSize + value = bytes.Repeat([]byte{'0'}, 100) + ) + for i := 0; i < 2; i++ { + tw, err := s.tops.create() + if err != nil { + t.Fatal(err) + } + for k := 0; tw.tw.BytesLen() < targetSize; k++ { + key := []byte(fmt.Sprintf("%09d", k)) + seq += nSeq - 1 + for x := uint64(0); x < nSeq; 
x++ { + if err := tw.append(makeInternalKey(nil, key, seq-x, keyTypeVal), value); err != nil { + t.Fatal(err) + } + } + } + tf, err := tw.finish() + if err != nil { + t.Fatal(err) + } + rec := &sessionRecord{} + rec.addTableFile(i, tf) + if err := s.commit(rec, false); err != nil { + t.Fatal(err) + } + } + + // Build grandparent. + v := s.version() + c := newCompaction(s, v, 1, append(tFiles{}, v.levels[1]...), undefinedCompaction) + rec := &sessionRecord{} + b := &tableCompactionBuilder{ + s: s, + c: c, + rec: rec, + stat1: new(cStatStaging), + minSeq: 0, + strict: true, + tableSize: o.CompactionTableSize/3 + 961, + } + if err := b.run(new(compactionTransactCounter)); err != nil { + t.Fatal(err) + } + for _, t := range c.levels[0] { + rec.delTable(c.sourceLevel, t.fd.Num) + } + if err := s.commit(rec, false); err != nil { + t.Fatal(err) + } + c.release() + + // Build level-1. + v = s.version() + c = newCompaction(s, v, 0, append(tFiles{}, v.levels[0]...), undefinedCompaction) + rec = &sessionRecord{} + b = &tableCompactionBuilder{ + s: s, + c: c, + rec: rec, + stat1: new(cStatStaging), + minSeq: 0, + strict: true, + tableSize: o.CompactionTableSize, + } + if err := b.run(new(compactionTransactCounter)); err != nil { + t.Fatal(err) + } + for _, t := range c.levels[0] { + rec.delTable(c.sourceLevel, t.fd.Num) + } + // Move grandparent to level-3 + for _, t := range v.levels[2] { + rec.delTable(2, t.fd.Num) + rec.addTableFile(3, t) + } + if err := s.commit(rec, false); err != nil { + t.Fatal(err) + } + c.release() + + v = s.version() + for level, want := range []bool{false, true, false, true} { + got := len(v.levels[level]) > 0 + if want != got { + t.Fatalf("invalid level-%d tables len: want %v, got %v", level, want, got) + } + } + for i, f := range v.levels[1][:len(v.levels[1])-1] { + nf := v.levels[1][i+1] + if bytes.Equal(f.imax.ukey(), nf.imin.ukey()) { + t.Fatalf("KEY %q hop across table %d .. %d", f.imax.ukey(), f.fd.Num, nf.fd.Num) + } + } + v.release() + + // Compaction with transient error. 
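+	// The steps below exercise recovery: transient storage errors are
+	// injected through the testutil storage, b.run is retried until it
+	// succeeds, and only corruption errors are treated as fatal.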
+ v = s.version() + c = newCompaction(s, v, 1, append(tFiles{}, v.levels[1]...), undefinedCompaction) + rec = &sessionRecord{} + b = &tableCompactionBuilder{ + s: s, + c: c, + rec: rec, + stat1: new(cStatStaging), + minSeq: 0, + strict: true, + tableSize: o.CompactionTableSize, + } + stor.EmulateErrorOnce(testutil.ModeSync, storage.TypeTable, errors.New("table sync error (once)")) + stor.EmulateRandomError(testutil.ModeRead|testutil.ModeWrite, storage.TypeTable, 0.01, errors.New("table random IO error")) + for { + if err := b.run(new(compactionTransactCounter)); err != nil { + t.Logf("(expected) b.run: %v", err) + } else { + break + } + } + if err := s.commit(rec, false); err != nil { + t.Fatal(err) + } + c.release() + + stor.EmulateErrorOnce(testutil.ModeSync, storage.TypeTable, nil) + stor.EmulateRandomError(testutil.ModeRead|testutil.ModeWrite, storage.TypeTable, 0, nil) + + v = s.version() + if len(v.levels[1]) != len(v.levels[2]) { + t.Fatalf("invalid tables length, want %d, got %d", len(v.levels[1]), len(v.levels[2])) + } + for i, f0 := range v.levels[1] { + f1 := v.levels[2][i] + iter0 := s.tops.newIterator(f0, nil, nil) + iter1 := s.tops.newIterator(f1, nil, nil) + for j := 0; true; j++ { + next0 := iter0.Next() + next1 := iter1.Next() + if next0 != next1 { + t.Fatalf("#%d.%d invalid eoi: want %v, got %v", i, j, next0, next1) + } + key0 := iter0.Key() + key1 := iter1.Key() + if !bytes.Equal(key0, key1) { + t.Fatalf("#%d.%d invalid key: want %q, got %q", i, j, key0, key1) + } + if next0 == false { + break + } + } + iter0.Release() + iter1.Release() + } + v.release() +} + +func testDB_IterTriggeredCompaction(t *testing.T, limitDiv int) { + const ( + vSize = 200 * opt.KiB + tSize = 100 * opt.MiB + mIter = 100 + n = tSize / vSize + ) + + h := newDbHarnessWopt(t, &opt.Options{ + DisableLargeBatchTransaction: true, + Compression: opt.NoCompression, + DisableBlockCache: true, + }) + defer h.close() + + h.db.memdbMaxLevel = 2 + + key := func(x int) string { + return fmt.Sprintf("v%06d", x) + } + + // Fill. + value := strings.Repeat("x", vSize) + for i := 0; i < n; i++ { + h.put(key(i), value) + } + h.compactMem() + + // Delete all. + for i := 0; i < n; i++ { + h.delete(key(i)) + } + h.compactMem() + + var ( + limit = n / limitDiv + + startKey = key(0) + limitKey = key(limit) + maxKey = key(n) + slice = &util.Range{Limit: []byte(limitKey)} + + initialSize0 = h.sizeOf(startKey, limitKey) + initialSize1 = h.sizeOf(limitKey, maxKey) + ) + + t.Logf("initial size %s [rest %s]", shortenb(int(initialSize0)), shortenb(int(initialSize1))) + + for r := 0; true; r++ { + if r >= mIter { + t.Fatal("taking too long to compact") + } + + // Iterates. + iter := h.db.NewIterator(slice, h.ro) + for iter.Next() { + } + if err := iter.Error(); err != nil { + t.Fatalf("Iter err: %v", err) + } + iter.Release() + + // Wait compaction. + h.waitCompaction() + + // Check size. 
+ size0 := h.sizeOf(startKey, limitKey) + size1 := h.sizeOf(limitKey, maxKey) + t.Logf("#%03d size %s [rest %s]", r, shortenb(int(size0)), shortenb(int(size1))) + if size0 < initialSize0/10 { + break + } + } + + if initialSize1 > 0 { + h.sizeAssert(limitKey, maxKey, initialSize1/4-opt.MiB, initialSize1+opt.MiB) + } +} + +func TestDB_IterTriggeredCompaction(t *testing.T) { + testDB_IterTriggeredCompaction(t, 1) +} + +func TestDB_IterTriggeredCompactionHalf(t *testing.T) { + testDB_IterTriggeredCompaction(t, 2) +} + +func TestDB_ReadOnly(t *testing.T) { + h := newDbHarness(t) + defer h.close() + + h.put("foo", "v1") + h.put("bar", "v2") + h.compactMem() + + h.put("xfoo", "v1") + h.put("xbar", "v2") + + t.Log("Trigger read-only") + if err := h.db.SetReadOnly(); err != nil { + h.close() + t.Fatalf("SetReadOnly error: %v", err) + } + + mode := testutil.ModeCreate | testutil.ModeRemove | testutil.ModeRename | testutil.ModeWrite | testutil.ModeSync + h.stor.EmulateError(mode, storage.TypeAll, errors.New("read-only DB shouldn't writes")) + + ro := func(key, value, wantValue string) { + if err := h.db.Put([]byte(key), []byte(value), h.wo); err != ErrReadOnly { + t.Fatalf("unexpected error: %v", err) + } + h.getVal(key, wantValue) + } + + ro("foo", "vx", "v1") + + h.o.ReadOnly = true + h.reopenDB() + + ro("foo", "vx", "v1") + ro("bar", "vx", "v2") + h.assertNumKeys(4) +} + +func TestDB_BulkInsertDelete(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + DisableLargeBatchTransaction: true, + Compression: opt.NoCompression, + CompactionTableSize: 128 * opt.KiB, + CompactionTotalSize: 1 * opt.MiB, + WriteBuffer: 256 * opt.KiB, + }) + defer h.close() + + const R = 100 + const N = 2500 + key := make([]byte, 4) + value := make([]byte, 256) + for i := 0; i < R; i++ { + offset := N * i + for j := 0; j < N; j++ { + binary.BigEndian.PutUint32(key, uint32(offset+j)) + h.db.Put(key, value, nil) + } + for j := 0; j < N; j++ { + binary.BigEndian.PutUint32(key, uint32(offset+j)) + h.db.Delete(key, nil) + } + } + + h.waitCompaction() + if tot := h.totalTables(); tot > 10 { + t.Fatalf("too many uncompacted tables: %d (%s)", tot, h.getTablesPerLevel()) + } +} + +func TestDB_GracefulClose(t *testing.T) { + runtime.GOMAXPROCS(4) + h := newDbHarnessWopt(t, &opt.Options{ + DisableLargeBatchTransaction: true, + Compression: opt.NoCompression, + CompactionTableSize: 1 * opt.MiB, + WriteBuffer: 1 * opt.MiB, + }) + defer h.close() + + var closeWait sync.WaitGroup + + // During write. + n := 0 + closing := false + for i := 0; i < 1000000; i++ { + if !closing && h.totalTables() > 3 { + t.Logf("close db during write, index=%d", i) + closeWait.Add(1) + go func() { + h.closeDB() + closeWait.Done() + }() + closing = true + } + if err := h.db.Put([]byte(fmt.Sprintf("%09d", i)), []byte(fmt.Sprintf("VAL-%09d", i)), h.wo); err != nil { + t.Logf("Put error: %s (expected)", err) + n = i + break + } + } + closeWait.Wait() + + // During read. + h.openDB() + closing = false + for i := 0; i < n; i++ { + if !closing && i > n/2 { + t.Logf("close db during read, index=%d", i) + closeWait.Add(1) + go func() { + h.closeDB() + closeWait.Done() + }() + closing = true + } + if _, err := h.db.Get([]byte(fmt.Sprintf("%09d", i)), h.ro); err != nil { + t.Logf("Get error: %s (expected)", err) + break + } + } + closeWait.Wait() + + // During iterate. 
+	h.openDB()
+	closing = false
+	iter := h.db.NewIterator(nil, h.ro)
+	for i := 0; iter.Next(); i++ {
+		if len(iter.Key()) == 0 || len(iter.Value()) == 0 {
+			t.Error("Key or value has zero length")
+		}
+		if !closing {
+			t.Logf("close db during iter, index=%d", i)
+			closeWait.Add(1)
+			go func() {
+				h.closeDB()
+				closeWait.Done()
+			}()
+			closing = true
+		}
+		time.Sleep(time.Millisecond)
+	}
+	if err := iter.Error(); err != nil {
+		t.Logf("Iter error: %s (expected)", err)
+	}
+	iter.Release()
+	closeWait.Wait()
+}
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/db_transaction.go b/vendor/github.com/syndtr/goleveldb/leveldb/db_transaction.go
new file mode 100644
index 0000000..21d1e51
--- /dev/null
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/db_transaction.go
@@ -0,0 +1,335 @@
+// Copyright (c) 2016, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package leveldb
+
+import (
+	"errors"
+	"sync"
+	"time"
+
+	"github.com/syndtr/goleveldb/leveldb/iterator"
+	"github.com/syndtr/goleveldb/leveldb/opt"
+	"github.com/syndtr/goleveldb/leveldb/util"
+)
+
+var errTransactionDone = errors.New("leveldb: transaction already closed")
+
+// Transaction is the transaction handle.
+type Transaction struct {
+	db        *DB
+	lk        sync.RWMutex
+	seq       uint64
+	mem       *memDB
+	tables    tFiles
+	ikScratch []byte
+	rec       sessionRecord
+	stats     cStatStaging
+	closed    bool
+}
+
+// Get gets the value for the given key. It returns ErrNotFound if the
+// DB does not contain the key.
+//
+// The returned slice is its own copy; it is safe to modify the contents
+// of the returned slice.
+// It is safe to modify the contents of the argument after Get returns.
+func (tr *Transaction) Get(key []byte, ro *opt.ReadOptions) ([]byte, error) {
+	tr.lk.RLock()
+	defer tr.lk.RUnlock()
+	if tr.closed {
+		return nil, errTransactionDone
+	}
+	return tr.db.get(tr.mem.DB, tr.tables, key, tr.seq, ro)
+}
+
+// Has returns true if the DB does contain the given key.
+//
+// It is safe to modify the contents of the argument after Has returns.
+func (tr *Transaction) Has(key []byte, ro *opt.ReadOptions) (bool, error) {
+	tr.lk.RLock()
+	defer tr.lk.RUnlock()
+	if tr.closed {
+		return false, errTransactionDone
+	}
+	return tr.db.has(tr.mem.DB, tr.tables, key, tr.seq, ro)
+}
+
+// NewIterator returns an iterator for the latest snapshot of the transaction.
+// The returned iterator is not safe for concurrent use, but it is safe to use
+// multiple iterators concurrently, with each in a dedicated goroutine.
+// It is also safe to use an iterator concurrently while writing to the
+// transaction. The resultant key/value pairs are guaranteed to be consistent.
+//
+// Slice allows slicing the iterator to only contain keys in the given
+// range. A nil Range.Start is treated as a key before all keys in the
+// DB. And a nil Range.Limit is treated as a key after all keys in
+// the DB.
+//
+// The returned iterator has locks on its own resources, so it can live beyond
+// the lifetime of the transaction that created it.
+//
+// WARNING: Any slice returned by the iterator (e.g. a slice returned by
+// calling the Iterator.Key() or Iterator.Value() methods) should not be
+// modified unless noted otherwise.
+//
+// The iterator must be released after use, by calling the Release method.
+//
+// Also read the Iterator documentation of the leveldb/iterator package.
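+//
+// A minimal usage sketch, assuming an open transaction tr:
+//
+//	iter := tr.NewIterator(nil, nil)
+//	for iter.Next() {
+//		// Use iter.Key() and iter.Value(); do not modify their contents.
+//	}
+//	iter.Release()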
+func (tr *Transaction) NewIterator(slice *util.Range, ro *opt.ReadOptions) iterator.Iterator {
+	tr.lk.RLock()
+	defer tr.lk.RUnlock()
+	if tr.closed {
+		return iterator.NewEmptyIterator(errTransactionDone)
+	}
+	tr.mem.incref()
+	return tr.db.newIterator(tr.mem, tr.tables, tr.seq, slice, ro)
+}
+
+func (tr *Transaction) flush() error {
+	// Flush memdb.
+	if tr.mem.Len() != 0 {
+		tr.stats.startTimer()
+		iter := tr.mem.NewIterator(nil)
+		t, n, err := tr.db.s.tops.createFrom(iter)
+		iter.Release()
+		tr.stats.stopTimer()
+		if err != nil {
+			return err
+		}
+		if tr.mem.getref() == 1 {
+			tr.mem.Reset()
+		} else {
+			tr.mem.decref()
+			tr.mem = tr.db.mpoolGet(0)
+			tr.mem.incref()
+		}
+		tr.tables = append(tr.tables, t)
+		tr.rec.addTableFile(0, t)
+		tr.stats.write += t.size
+		tr.db.logf("transaction@flush created L0@%d N·%d S·%s %q:%q", t.fd.Num, n, shortenb(int(t.size)), t.imin, t.imax)
+	}
+	return nil
+}
+
+func (tr *Transaction) put(kt keyType, key, value []byte) error {
+	tr.ikScratch = makeInternalKey(tr.ikScratch, key, tr.seq+1, kt)
+	if tr.mem.Free() < len(tr.ikScratch)+len(value) {
+		if err := tr.flush(); err != nil {
+			return err
+		}
+	}
+	if err := tr.mem.Put(tr.ikScratch, value); err != nil {
+		return err
+	}
+	tr.seq++
+	return nil
+}
+
+// Put sets the value for the given key. It overwrites any previous value
+// for that key; a DB is not a multi-map.
+// Please note that the transaction is not compacted until committed, so if
+// you write the same key 10 times, all 10 entries are kept in the transaction.
+//
+// It is safe to modify the contents of the arguments after Put returns.
+func (tr *Transaction) Put(key, value []byte, wo *opt.WriteOptions) error {
+	tr.lk.Lock()
+	defer tr.lk.Unlock()
+	if tr.closed {
+		return errTransactionDone
+	}
+	return tr.put(keyTypeVal, key, value)
+}
+
+// Delete deletes the value for the given key.
+// Please note that the transaction is not compacted until committed, so if
+// you write the same key 10 times, all 10 entries are kept in the transaction.
+//
+// It is safe to modify the contents of the arguments after Delete returns.
+func (tr *Transaction) Delete(key []byte, wo *opt.WriteOptions) error {
+	tr.lk.Lock()
+	defer tr.lk.Unlock()
+	if tr.closed {
+		return errTransactionDone
+	}
+	return tr.put(keyTypeDel, key, nil)
+}
+
+// Write applies the given batch to the transaction. The batch will be applied
+// sequentially.
+// Please note that the transaction is not compacted until committed, so if
+// you write the same key 10 times, all 10 entries are kept in the transaction.
+//
+// It is safe to modify the contents of the arguments after Write returns.
+func (tr *Transaction) Write(b *Batch, wo *opt.WriteOptions) error {
+	if b == nil || b.Len() == 0 {
+		return nil
+	}
+
+	tr.lk.Lock()
+	defer tr.lk.Unlock()
+	if tr.closed {
+		return errTransactionDone
+	}
+	return b.replayInternal(func(i int, kt keyType, k, v []byte) error {
+		return tr.put(kt, k, v)
+	})
+}
+
+func (tr *Transaction) setDone() {
+	tr.closed = true
+	tr.db.tr = nil
+	tr.mem.decref()
+	<-tr.db.writeLockC
+}
+
+// Commit commits the transaction. If the returned error is not nil, then the
+// transaction is not committed; it can then either be retried or discarded.
+//
+// Other methods should not be called after the transaction has been committed.
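+//
+// A minimal commit-or-discard sketch, assuming an opened *DB named db:
+//
+//	tr, err := db.OpenTransaction()
+//	if err != nil {
+//		return err
+//	}
+//	if err := tr.Put([]byte("key"), []byte("value"), nil); err != nil {
+//		tr.Discard()
+//		return err
+//	}
+//	return tr.Commit()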
+func (tr *Transaction) Commit() error {
+	if err := tr.db.ok(); err != nil {
+		return err
+	}
+
+	tr.lk.Lock()
+	defer tr.lk.Unlock()
+	if tr.closed {
+		return errTransactionDone
+	}
+	if err := tr.flush(); err != nil {
+		// Return the error and let the user decide whether to retry or
+		// discard the transaction.
+		return err
+	}
+	if len(tr.tables) != 0 {
+		// Committing transaction.
+		tr.rec.setSeqNum(tr.seq)
+		tr.db.compCommitLk.Lock()
+		tr.stats.startTimer()
+		var cerr error
+		for retry := 0; retry < 3; retry++ {
+			cerr = tr.db.s.commit(&tr.rec, false)
+			if cerr != nil {
+				tr.db.logf("transaction@commit error R·%d %q", retry, cerr)
+				select {
+				case <-time.After(time.Second):
+				case <-tr.db.closeC:
+					tr.db.logf("transaction@commit exiting")
+					tr.db.compCommitLk.Unlock()
+					return cerr
+				}
+			} else {
+				// Success. Set db.seq.
+				tr.db.setSeq(tr.seq)
+				break
+			}
+		}
+		tr.stats.stopTimer()
+		if cerr != nil {
+			// Return the error and let the user decide whether to retry or
+			// discard the transaction.
+			return cerr
+		}
+
+		// Update compaction stats. This is safe as long as we hold compCommitLk.
+		tr.db.compStats.addStat(0, &tr.stats)
+
+		// Trigger table auto-compaction.
+		tr.db.compTrigger(tr.db.tcompCmdC)
+		tr.db.compCommitLk.Unlock()
+
+		// Additionally, wait for compaction when a certain threshold is
+		// reached. Ignore the error; Commit returns an error only if the
+		// transaction can't be committed.
+		tr.db.waitCompaction()
+	}
+	// Only mark as done if the transaction committed successfully.
+	tr.setDone()
+	return nil
+}
+
+func (tr *Transaction) discard() {
+	// Discard transaction.
+	for _, t := range tr.tables {
+		tr.db.logf("transaction@discard @%d", t.fd.Num)
+		// Iterator may still use the table, so we use tOps.remove here.
+		tr.db.s.tops.remove(t.fd)
+	}
+}
+
+// Discard discards the transaction.
+// This method is a no-op if the transaction has already been closed (either
+// committed or discarded).
+//
+// Other methods should not be called after the transaction has been discarded.
+func (tr *Transaction) Discard() {
+	tr.lk.Lock()
+	if !tr.closed {
+		tr.discard()
+		tr.setDone()
+	}
+	tr.lk.Unlock()
+}
+
+func (db *DB) waitCompaction() error {
+	if db.s.tLen(0) >= db.s.o.GetWriteL0PauseTrigger() {
+		return db.compTriggerWait(db.tcompCmdC)
+	}
+	return nil
+}
+
+// OpenTransaction opens an atomic DB transaction. Only one transaction can be
+// opened at a time. Subsequent calls to Write and OpenTransaction will block
+// until the in-flight transaction is committed or discarded.
+// The returned transaction handle is safe for concurrent use.
+//
+// Transactions are very expensive and can overwhelm compaction, especially if
+// the transaction size is small. Use with caution.
+// The rule of thumb: if you need to merge at least the same amount of data as
+// `Options.WriteBuffer`, use a transaction; otherwise don't.
+//
+// The transaction must be closed once done, either by committing or discarding
+// the transaction.
+// Closing the DB will discard any open transaction.
+func (db *DB) OpenTransaction() (*Transaction, error) {
+	if err := db.ok(); err != nil {
+		return nil, err
+	}
+
+	// The write happens synchronously.
+	select {
+	case db.writeLockC <- struct{}{}:
+	case err := <-db.compPerErrC:
+		return nil, err
+	case <-db.closeC:
+		return nil, ErrClosed
+	}
+
+	if db.tr != nil {
+		panic("leveldb: has open transaction")
+	}
+
+	// Flush current memdb.
+	if db.mem != nil && db.mem.Len() != 0 {
+		if _, err := db.rotateMem(0, true); err != nil {
+			return nil, err
+		}
+	}
+
+	// Wait for compaction when a certain threshold is reached.
+ if err := db.waitCompaction(); err != nil { + return nil, err + } + + tr := &Transaction{ + db: db, + seq: db.seq, + mem: db.mpoolGet(0), + } + tr.mem.incref() + db.tr = tr + return tr, nil +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/db_util.go b/vendor/github.com/syndtr/goleveldb/leveldb/db_util.go new file mode 100644 index 0000000..3f06548 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/db_util.go @@ -0,0 +1,102 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" + "github.com/syndtr/goleveldb/leveldb/util" +) + +// Reader is the interface that wraps basic Get and NewIterator methods. +// This interface implemented by both DB and Snapshot. +type Reader interface { + Get(key []byte, ro *opt.ReadOptions) (value []byte, err error) + NewIterator(slice *util.Range, ro *opt.ReadOptions) iterator.Iterator +} + +// Sizes is list of size. +type Sizes []int64 + +// Sum returns sum of the sizes. +func (sizes Sizes) Sum() int64 { + var sum int64 + for _, size := range sizes { + sum += size + } + return sum +} + +// Logging. +func (db *DB) log(v ...interface{}) { db.s.log(v...) } +func (db *DB) logf(format string, v ...interface{}) { db.s.logf(format, v...) } + +// Check and clean files. +func (db *DB) checkAndCleanFiles() error { + v := db.s.version() + defer v.release() + + tmap := make(map[int64]bool) + for _, tables := range v.levels { + for _, t := range tables { + tmap[t.fd.Num] = false + } + } + + fds, err := db.s.stor.List(storage.TypeAll) + if err != nil { + return err + } + + var nt int + var rem []storage.FileDesc + for _, fd := range fds { + keep := true + switch fd.Type { + case storage.TypeManifest: + keep = fd.Num >= db.s.manifestFd.Num + case storage.TypeJournal: + if !db.frozenJournalFd.Zero() { + keep = fd.Num >= db.frozenJournalFd.Num + } else { + keep = fd.Num >= db.journalFd.Num + } + case storage.TypeTable: + _, keep = tmap[fd.Num] + if keep { + tmap[fd.Num] = true + nt++ + } + } + + if !keep { + rem = append(rem, fd) + } + } + + if nt != len(tmap) { + var mfds []storage.FileDesc + for num, present := range tmap { + if !present { + mfds = append(mfds, storage.FileDesc{Type: storage.TypeTable, Num: num}) + db.logf("db@janitor table missing @%d", num) + } + } + return errors.NewErrCorrupted(storage.FileDesc{}, &errors.ErrMissingFiles{Fds: mfds}) + } + + db.logf("db@janitor F·%d G·%d", len(fds), len(rem)) + for _, fd := range rem { + db.logf("db@janitor removing %s-%d", fd.Type, fd.Num) + if err := db.s.stor.Remove(fd); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/db_write.go b/vendor/github.com/syndtr/goleveldb/leveldb/db_write.go new file mode 100644 index 0000000..db0c1be --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/db_write.go @@ -0,0 +1,464 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
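+
+// A minimal usage sketch for the write path implemented in this file,
+// assuming an opened *DB named db (see also the examples in doc.go):
+//
+//	batch := new(leveldb.Batch)
+//	batch.Put([]byte("foo"), []byte("value"))
+//	batch.Delete([]byte("bar"))
+//	err := db.Write(batch, nil)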
+ +package leveldb + +import ( + "sync/atomic" + "time" + + "github.com/syndtr/goleveldb/leveldb/memdb" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/util" +) + +func (db *DB) writeJournal(batches []*Batch, seq uint64, sync bool) error { + wr, err := db.journal.Next() + if err != nil { + return err + } + if err := writeBatchesWithHeader(wr, batches, seq); err != nil { + return err + } + if err := db.journal.Flush(); err != nil { + return err + } + if sync { + return db.journalWriter.Sync() + } + return nil +} + +func (db *DB) rotateMem(n int, wait bool) (mem *memDB, err error) { + retryLimit := 3 +retry: + // Wait for pending memdb compaction. + err = db.compTriggerWait(db.mcompCmdC) + if err != nil { + return + } + retryLimit-- + + // Create new memdb and journal. + mem, err = db.newMem(n) + if err != nil { + if err == errHasFrozenMem { + if retryLimit <= 0 { + panic("BUG: still has frozen memdb") + } + goto retry + } + return + } + + // Schedule memdb compaction. + if wait { + err = db.compTriggerWait(db.mcompCmdC) + } else { + db.compTrigger(db.mcompCmdC) + } + return +} + +func (db *DB) flush(n int) (mdb *memDB, mdbFree int, err error) { + delayed := false + slowdownTrigger := db.s.o.GetWriteL0SlowdownTrigger() + pauseTrigger := db.s.o.GetWriteL0PauseTrigger() + flush := func() (retry bool) { + mdb = db.getEffectiveMem() + if mdb == nil { + err = ErrClosed + return false + } + defer func() { + if retry { + mdb.decref() + mdb = nil + } + }() + tLen := db.s.tLen(0) + mdbFree = mdb.Free() + switch { + case tLen >= slowdownTrigger && !delayed: + delayed = true + time.Sleep(time.Millisecond) + case mdbFree >= n: + return false + case tLen >= pauseTrigger: + delayed = true + // Set the write paused flag explicitly. + atomic.StoreInt32(&db.inWritePaused, 1) + err = db.compTriggerWait(db.tcompCmdC) + // Unset the write paused flag. + atomic.StoreInt32(&db.inWritePaused, 0) + if err != nil { + return false + } + default: + // Allow memdb to grow if it has no entry. + if mdb.Len() == 0 { + mdbFree = n + } else { + mdb.decref() + mdb, err = db.rotateMem(n, false) + if err == nil { + mdbFree = mdb.Free() + } else { + mdbFree = 0 + } + } + return false + } + return true + } + start := time.Now() + for flush() { + } + if delayed { + db.writeDelay += time.Since(start) + db.writeDelayN++ + } else if db.writeDelayN > 0 { + db.logf("db@write was delayed N·%d T·%v", db.writeDelayN, db.writeDelay) + atomic.AddInt32(&db.cWriteDelayN, int32(db.writeDelayN)) + atomic.AddInt64(&db.cWriteDelay, int64(db.writeDelay)) + db.writeDelay = 0 + db.writeDelayN = 0 + } + return +} + +type writeMerge struct { + sync bool + batch *Batch + keyType keyType + key, value []byte +} + +func (db *DB) unlockWrite(overflow bool, merged int, err error) { + for i := 0; i < merged; i++ { + db.writeAckC <- err + } + if overflow { + // Pass lock to the next write (that failed to merge). + db.writeMergedC <- false + } else { + // Release lock. + <-db.writeLockC + } +} + +// ourBatch is batch that we can modify. +func (db *DB) writeLocked(batch, ourBatch *Batch, merge, sync bool) error { + // Try to flush memdb. This method would also trying to throttle writes + // if it is too fast and compaction cannot catch-up. + mdb, mdbFree, err := db.flush(batch.internalLen) + if err != nil { + db.unlockWrite(false, 0, err) + return err + } + defer mdb.decref() + + var ( + overflow bool + merged int + batches = []*Batch{batch} + ) + + if merge { + // Merge limit. 
+		var mergeLimit int
+		if batch.internalLen > 128<<10 {
+			mergeLimit = (1 << 20) - batch.internalLen
+		} else {
+			mergeLimit = 128 << 10
+		}
+		mergeCap := mdbFree - batch.internalLen
+		if mergeLimit > mergeCap {
+			mergeLimit = mergeCap
+		}
+
+	merge:
+		for mergeLimit > 0 {
+			select {
+			case incoming := <-db.writeMergeC:
+				if incoming.batch != nil {
+					// Merge batch.
+					if incoming.batch.internalLen > mergeLimit {
+						overflow = true
+						break merge
+					}
+					batches = append(batches, incoming.batch)
+					mergeLimit -= incoming.batch.internalLen
+				} else {
+					// Merge put.
+					internalLen := len(incoming.key) + len(incoming.value) + 8
+					if internalLen > mergeLimit {
+						overflow = true
+						break merge
+					}
+					if ourBatch == nil {
+						ourBatch = db.batchPool.Get().(*Batch)
+						ourBatch.Reset()
+						batches = append(batches, ourBatch)
+					}
+					// We can reuse the same batch since concurrent writes
+					// don't guarantee write order.
+					ourBatch.appendRec(incoming.keyType, incoming.key, incoming.value)
+					mergeLimit -= internalLen
+				}
+				sync = sync || incoming.sync
+				merged++
+				db.writeMergedC <- true
+
+			default:
+				break merge
+			}
+		}
+	}
+
+	// Release ourBatch if any.
+	if ourBatch != nil {
+		defer db.batchPool.Put(ourBatch)
+	}
+
+	// Seq number.
+	seq := db.seq + 1
+
+	// Write journal.
+	if err := db.writeJournal(batches, seq, sync); err != nil {
+		db.unlockWrite(overflow, merged, err)
+		return err
+	}
+
+	// Put batches.
+	for _, batch := range batches {
+		if err := batch.putMem(seq, mdb.DB); err != nil {
+			panic(err)
+		}
+		seq += uint64(batch.Len())
+	}
+
+	// Incr seq number.
+	db.addSeq(uint64(batchesLen(batches)))
+
+	// Rotate memdb if it reaches the threshold.
+	if batch.internalLen >= mdbFree {
+		db.rotateMem(0, false)
+	}
+
+	db.unlockWrite(overflow, merged, nil)
+	return nil
+}
+
+// Write applies the given batch to the DB. The batch records will be applied
+// sequentially. Write may be used concurrently; when it is used concurrently
+// and the batch is small enough, Write will try to merge the batches. Set the
+// NoWriteMerge option to true to disable write merge.
+//
+// It is safe to modify the contents of the arguments after Write returns but
+// not before. Write will not modify the content of the batch.
+func (db *DB) Write(batch *Batch, wo *opt.WriteOptions) error {
+	if err := db.ok(); err != nil || batch == nil || batch.Len() == 0 {
+		return err
+	}
+
+	// If the batch size is larger than the write buffer, it may be justified
+	// to write it using a transaction instead. Using a transaction, the batch
+	// will be written into tables directly, skipping the journaling.
+	if batch.internalLen > db.s.o.GetWriteBuffer() && !db.s.o.GetDisableLargeBatchTransaction() {
+		tr, err := db.OpenTransaction()
+		if err != nil {
+			return err
+		}
+		if err := tr.Write(batch, wo); err != nil {
+			tr.Discard()
+			return err
+		}
+		return tr.Commit()
+	}
+
+	merge := !wo.GetNoWriteMerge() && !db.s.o.GetNoWriteMerge()
+	sync := wo.GetSync() && !db.s.o.GetNoSync()
+
+	// Acquire write lock.
+	if merge {
+		select {
+		case db.writeMergeC <- writeMerge{sync: sync, batch: batch}:
+			if <-db.writeMergedC {
+				// Write is merged.
+				return <-db.writeAckC
+			}
+			// Write is not merged, the write lock is handed to us. Continue.
+		case db.writeLockC <- struct{}{}:
+			// Write lock acquired.
+		case err := <-db.compPerErrC:
+			// Compaction error.
+			return err
+		case <-db.closeC:
+			// Closed
+			return ErrClosed
+		}
+	} else {
+		select {
+		case db.writeLockC <- struct{}{}:
+			// Write lock acquired.
+		case err := <-db.compPerErrC:
+			// Compaction error.
+			return err
+		case <-db.closeC:
+			// Closed
+			return ErrClosed
+		}
+	}
+
+	return db.writeLocked(batch, nil, merge, sync)
+}
+
+func (db *DB) putRec(kt keyType, key, value []byte, wo *opt.WriteOptions) error {
+	if err := db.ok(); err != nil {
+		return err
+	}
+
+	merge := !wo.GetNoWriteMerge() && !db.s.o.GetNoWriteMerge()
+	sync := wo.GetSync() && !db.s.o.GetNoSync()
+
+	// Acquire write lock.
+	if merge {
+		select {
+		case db.writeMergeC <- writeMerge{sync: sync, keyType: kt, key: key, value: value}:
+			if <-db.writeMergedC {
+				// Write is merged.
+				return <-db.writeAckC
+			}
+			// Write is not merged, the write lock is handed to us. Continue.
+		case db.writeLockC <- struct{}{}:
+			// Write lock acquired.
+		case err := <-db.compPerErrC:
+			// Compaction error.
+			return err
+		case <-db.closeC:
+			// Closed
+			return ErrClosed
+		}
+	} else {
+		select {
+		case db.writeLockC <- struct{}{}:
+			// Write lock acquired.
+		case err := <-db.compPerErrC:
+			// Compaction error.
+			return err
+		case <-db.closeC:
+			// Closed
+			return ErrClosed
+		}
+	}
+
+	batch := db.batchPool.Get().(*Batch)
+	batch.Reset()
+	batch.appendRec(kt, key, value)
+	return db.writeLocked(batch, batch, merge, sync)
+}
+
+// Put sets the value for the given key. It overwrites any previous value
+// for that key; a DB is not a multi-map. Write merge also applies for Put, see
+// Write.
+//
+// It is safe to modify the contents of the arguments after Put returns but not
+// before.
+func (db *DB) Put(key, value []byte, wo *opt.WriteOptions) error {
+	return db.putRec(keyTypeVal, key, value, wo)
+}
+
+// Delete deletes the value for the given key. Delete will not return an error
+// if the key doesn't exist. Write merge also applies for Delete, see Write.
+//
+// It is safe to modify the contents of the arguments after Delete returns but
+// not before.
+func (db *DB) Delete(key []byte, wo *opt.WriteOptions) error {
+	return db.putRec(keyTypeDel, key, nil, wo)
+}
+
+func isMemOverlaps(icmp *iComparer, mem *memdb.DB, min, max []byte) bool {
+	iter := mem.NewIterator(nil)
+	defer iter.Release()
+	return (max == nil || (iter.First() && icmp.uCompare(max, internalKey(iter.Key()).ukey()) >= 0)) &&
+		(min == nil || (iter.Last() && icmp.uCompare(min, internalKey(iter.Key()).ukey()) <= 0))
+}
+
+// CompactRange compacts the underlying DB for the given key range.
+// In particular, deleted and overwritten versions are discarded,
+// and the data is rearranged to reduce the cost of operations
+// needed to access the data. This operation should typically only
+// be invoked by users who understand the underlying implementation.
+//
+// A nil Range.Start is treated as a key before all keys in the DB.
+// And a nil Range.Limit is treated as a key after all keys in the DB.
+// Therefore, if both are nil, the entire DB will be compacted.
+func (db *DB) CompactRange(r util.Range) error {
+	if err := db.ok(); err != nil {
+		return err
+	}
+
+	// Lock writer.
+	select {
+	case db.writeLockC <- struct{}{}:
+	case err := <-db.compPerErrC:
+		return err
+	case <-db.closeC:
+		return ErrClosed
+	}
+
+	// Check for overlaps in memdb.
+	mdb := db.getEffectiveMem()
+	if mdb == nil {
+		return ErrClosed
+	}
+	defer mdb.decref()
+	if isMemOverlaps(db.s.icmp, mdb.DB, r.Start, r.Limit) {
+		// Memdb compaction.
+		if _, err := db.rotateMem(0, false); err != nil {
+			<-db.writeLockC
+			return err
+		}
+		<-db.writeLockC
+		if err := db.compTriggerWait(db.mcompCmdC); err != nil {
+			return err
+		}
+	} else {
+		<-db.writeLockC
+	}
+
+	// Table compaction.
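+	// A typical call compacts the entire key space by passing the zero
+	// util.Range, since nil Start and Limit cover all keys (a usage sketch):
+	//
+	//	err := db.CompactRange(util.Range{})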
+ return db.compTriggerRange(db.tcompCmdC, -1, r.Start, r.Limit) +} + +// SetReadOnly makes DB read-only. It will stay read-only until reopened. +func (db *DB) SetReadOnly() error { + if err := db.ok(); err != nil { + return err + } + + // Lock writer. + select { + case db.writeLockC <- struct{}{}: + db.compWriteLocking = true + case err := <-db.compPerErrC: + return err + case <-db.closeC: + return ErrClosed + } + + // Set compaction read-only. + select { + case db.compErrSetC <- ErrReadOnly: + case perr := <-db.compPerErrC: + return perr + case <-db.closeC: + return ErrClosed + } + + return nil +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/doc.go b/vendor/github.com/syndtr/goleveldb/leveldb/doc.go new file mode 100644 index 0000000..be768e5 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/doc.go @@ -0,0 +1,92 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Package leveldb provides implementation of LevelDB key/value database. +// +// Create or open a database: +// +// // The returned DB instance is safe for concurrent use. Which mean that all +// // DB's methods may be called concurrently from multiple goroutine. +// db, err := leveldb.OpenFile("path/to/db", nil) +// ... +// defer db.Close() +// ... +// +// Read or modify the database content: +// +// // Remember that the contents of the returned slice should not be modified. +// data, err := db.Get([]byte("key"), nil) +// ... +// err = db.Put([]byte("key"), []byte("value"), nil) +// ... +// err = db.Delete([]byte("key"), nil) +// ... +// +// Iterate over database content: +// +// iter := db.NewIterator(nil, nil) +// for iter.Next() { +// // Remember that the contents of the returned slice should not be modified, and +// // only valid until the next call to Next. +// key := iter.Key() +// value := iter.Value() +// ... +// } +// iter.Release() +// err = iter.Error() +// ... +// +// Iterate over subset of database content with a particular prefix: +// iter := db.NewIterator(util.BytesPrefix([]byte("foo-")), nil) +// for iter.Next() { +// // Use key/value. +// ... +// } +// iter.Release() +// err = iter.Error() +// ... +// +// Seek-then-Iterate: +// +// iter := db.NewIterator(nil, nil) +// for ok := iter.Seek(key); ok; ok = iter.Next() { +// // Use key/value. +// ... +// } +// iter.Release() +// err = iter.Error() +// ... +// +// Iterate over subset of database content: +// +// iter := db.NewIterator(&util.Range{Start: []byte("foo"), Limit: []byte("xoo")}, nil) +// for iter.Next() { +// // Use key/value. +// ... +// } +// iter.Release() +// err = iter.Error() +// ... +// +// Batch writes: +// +// batch := new(leveldb.Batch) +// batch.Put([]byte("foo"), []byte("value")) +// batch.Put([]byte("bar"), []byte("another value")) +// batch.Delete([]byte("baz")) +// err = db.Write(batch, nil) +// ... +// +// Use bloom filter: +// +// o := &opt.Options{ +// Filter: filter.NewBloomFilter(10), +// } +// db, err := leveldb.OpenFile("path/to/db", o) +// ... +// defer db.Close() +// ... +package leveldb diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/errors.go b/vendor/github.com/syndtr/goleveldb/leveldb/errors.go new file mode 100644 index 0000000..de26498 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/errors.go @@ -0,0 +1,20 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. 
+// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "github.com/syndtr/goleveldb/leveldb/errors" +) + +// Common errors. +var ( + ErrNotFound = errors.ErrNotFound + ErrReadOnly = errors.New("leveldb: read-only mode") + ErrSnapshotReleased = errors.New("leveldb: snapshot released") + ErrIterReleased = errors.New("leveldb: iterator released") + ErrClosed = errors.New("leveldb: closed") +) diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/errors/errors.go b/vendor/github.com/syndtr/goleveldb/leveldb/errors/errors.go new file mode 100644 index 0000000..8d6146b --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/errors/errors.go @@ -0,0 +1,78 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Package errors provides common error types used throughout leveldb. +package errors + +import ( + "errors" + "fmt" + + "github.com/syndtr/goleveldb/leveldb/storage" + "github.com/syndtr/goleveldb/leveldb/util" +) + +// Common errors. +var ( + ErrNotFound = New("leveldb: not found") + ErrReleased = util.ErrReleased + ErrHasReleaser = util.ErrHasReleaser +) + +// New returns an error that formats as the given text. +func New(text string) error { + return errors.New(text) +} + +// ErrCorrupted is the type that wraps errors that indicate corruption in +// the database. +type ErrCorrupted struct { + Fd storage.FileDesc + Err error +} + +func (e *ErrCorrupted) Error() string { + if !e.Fd.Zero() { + return fmt.Sprintf("%v [file=%v]", e.Err, e.Fd) + } + return e.Err.Error() +} + +// NewErrCorrupted creates new ErrCorrupted error. +func NewErrCorrupted(fd storage.FileDesc, err error) error { + return &ErrCorrupted{fd, err} +} + +// IsCorrupted returns a boolean indicating whether the error is indicating +// a corruption. +func IsCorrupted(err error) bool { + switch err.(type) { + case *ErrCorrupted: + return true + case *storage.ErrCorrupted: + return true + } + return false +} + +// ErrMissingFiles is the type that indicating a corruption due to missing +// files. ErrMissingFiles always wrapped with ErrCorrupted. +type ErrMissingFiles struct { + Fds []storage.FileDesc +} + +func (e *ErrMissingFiles) Error() string { return "file missing" } + +// SetFd sets 'file info' of the given error with the given file. +// Currently only ErrCorrupted is supported, otherwise will do nothing. +func SetFd(err error, fd storage.FileDesc) error { + switch x := err.(type) { + case *ErrCorrupted: + x.Fd = fd + return x + } + return err +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/external_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/external_test.go new file mode 100644 index 0000000..669d77f --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/external_test.go @@ -0,0 +1,117 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + . "github.com/onsi/ginkgo" + . 
"github.com/onsi/gomega" + + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/testutil" +) + +var _ = testutil.Defer(func() { + Describe("Leveldb external", func() { + o := &opt.Options{ + DisableBlockCache: true, + BlockRestartInterval: 5, + BlockSize: 80, + Compression: opt.NoCompression, + OpenFilesCacheCapacity: -1, + Strict: opt.StrictAll, + WriteBuffer: 1000, + CompactionTableSize: 2000, + } + + Describe("write test", func() { + It("should do write correctly", func(done Done) { + db := newTestingDB(o, nil, nil) + t := testutil.DBTesting{ + DB: db, + Deleted: testutil.KeyValue_Generate(nil, 500, 1, 1, 50, 5, 5).Clone(), + } + testutil.DoDBTesting(&t) + db.TestClose() + done <- true + }, 80.0) + }) + + Describe("read test", func() { + testutil.AllKeyValueTesting(nil, nil, func(kv testutil.KeyValue) testutil.DB { + // Building the DB. + db := newTestingDB(o, nil, nil) + kv.IterateShuffled(nil, func(i int, key, value []byte) { + err := db.TestPut(key, value) + Expect(err).NotTo(HaveOccurred()) + }) + + return db + }, func(db testutil.DB) { + db.(*testingDB).TestClose() + }) + }) + + Describe("transaction test", func() { + It("should do transaction correctly", func(done Done) { + db := newTestingDB(o, nil, nil) + + By("creating first transaction") + var err error + tr := &testingTransaction{} + tr.Transaction, err = db.OpenTransaction() + Expect(err).NotTo(HaveOccurred()) + t0 := &testutil.DBTesting{ + DB: tr, + Deleted: testutil.KeyValue_Generate(nil, 200, 1, 1, 50, 5, 5).Clone(), + } + testutil.DoDBTesting(t0) + testutil.TestGet(tr, t0.Present) + testutil.TestHas(tr, t0.Present) + + By("committing first transaction") + err = tr.Commit() + Expect(err).NotTo(HaveOccurred()) + testutil.TestIter(db, nil, t0.Present) + testutil.TestGet(db, t0.Present) + testutil.TestHas(db, t0.Present) + + By("manipulating DB without transaction") + t0.DB = db + testutil.DoDBTesting(t0) + + By("creating second transaction") + tr.Transaction, err = db.OpenTransaction() + Expect(err).NotTo(HaveOccurred()) + t1 := &testutil.DBTesting{ + DB: tr, + Deleted: t0.Deleted.Clone(), + Present: t0.Present.Clone(), + } + testutil.DoDBTesting(t1) + testutil.TestIter(db, nil, t0.Present) + + By("discarding second transaction") + tr.Discard() + testutil.TestIter(db, nil, t0.Present) + + By("creating third transaction") + tr.Transaction, err = db.OpenTransaction() + Expect(err).NotTo(HaveOccurred()) + t0.DB = tr + testutil.DoDBTesting(t0) + + By("committing third transaction") + err = tr.Commit() + Expect(err).NotTo(HaveOccurred()) + testutil.TestIter(db, nil, t0.Present) + + db.TestClose() + done <- true + }, 240.0) + }) + }) +}) diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/filter.go b/vendor/github.com/syndtr/goleveldb/leveldb/filter.go new file mode 100644 index 0000000..e961e42 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/filter.go @@ -0,0 +1,31 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +package leveldb + +import ( + "github.com/syndtr/goleveldb/leveldb/filter" +) + +type iFilter struct { + filter.Filter +} + +func (f iFilter) Contains(filter, key []byte) bool { + return f.Filter.Contains(filter, internalKey(key).ukey()) +} + +func (f iFilter) NewGenerator() filter.FilterGenerator { + return iFilterGenerator{f.Filter.NewGenerator()} +} + +type iFilterGenerator struct { + filter.FilterGenerator +} + +func (g iFilterGenerator) Add(key []byte) { + g.FilterGenerator.Add(internalKey(key).ukey()) +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/filter/bloom.go b/vendor/github.com/syndtr/goleveldb/leveldb/filter/bloom.go new file mode 100644 index 0000000..56ccbfb --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/filter/bloom.go @@ -0,0 +1,116 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package filter + +import ( + "github.com/syndtr/goleveldb/leveldb/util" +) + +func bloomHash(key []byte) uint32 { + return util.Hash(key, 0xbc9f1d34) +} + +type bloomFilter int + +// Name: The bloom filter serializes its parameters and is backward compatible +// with respect to them. Therefor, its parameters are not added to its +// name. +func (bloomFilter) Name() string { + return "leveldb.BuiltinBloomFilter" +} + +func (f bloomFilter) Contains(filter, key []byte) bool { + nBytes := len(filter) - 1 + if nBytes < 1 { + return false + } + nBits := uint32(nBytes * 8) + + // Use the encoded k so that we can read filters generated by + // bloom filters created using different parameters. + k := filter[nBytes] + if k > 30 { + // Reserved for potentially new encodings for short bloom filters. + // Consider it a match. + return true + } + + kh := bloomHash(key) + delta := (kh >> 17) | (kh << 15) // Rotate right 17 bits + for j := uint8(0); j < k; j++ { + bitpos := kh % nBits + if (uint32(filter[bitpos/8]) & (1 << (bitpos % 8))) == 0 { + return false + } + kh += delta + } + return true +} + +func (f bloomFilter) NewGenerator() FilterGenerator { + // Round down to reduce probing cost a little bit. + k := uint8(f * 69 / 100) // 0.69 =~ ln(2) + if k < 1 { + k = 1 + } else if k > 30 { + k = 30 + } + return &bloomFilterGenerator{ + n: int(f), + k: k, + } +} + +type bloomFilterGenerator struct { + n int + k uint8 + + keyHashes []uint32 +} + +func (g *bloomFilterGenerator) Add(key []byte) { + // Use double-hashing to generate a sequence of hash values. + // See analysis in [Kirsch,Mitzenmacher 2006]. + g.keyHashes = append(g.keyHashes, bloomHash(key)) +} + +func (g *bloomFilterGenerator) Generate(b Buffer) { + // Compute bloom filter size (in both bits and bytes) + nBits := uint32(len(g.keyHashes) * g.n) + // For small n, we can see a very high false positive rate. Fix it + // by enforcing a minimum bloom filter length. + if nBits < 64 { + nBits = 64 + } + nBytes := (nBits + 7) / 8 + nBits = nBytes * 8 + + dest := b.Alloc(int(nBytes) + 1) + dest[nBytes] = g.k + for _, kh := range g.keyHashes { + delta := (kh >> 17) | (kh << 15) // Rotate right 17 bits + for j := uint8(0); j < g.k; j++ { + bitpos := kh % nBits + dest[bitpos/8] |= (1 << (bitpos % 8)) + kh += delta + } + } + + g.keyHashes = g.keyHashes[:0] +} + +// NewBloomFilter creates a new initialized bloom filter for given +// bitsPerKey. 
+// +// Since bitsPerKey is persisted individually for each bloom filter +// serialization, bloom filters are backwards compatible with respect to +// changing bitsPerKey. This means that no big performance penalty will +// be experienced when changing the parameter. See documentation for +// opt.Options.Filter for more information. +func NewBloomFilter(bitsPerKey int) Filter { + return bloomFilter(bitsPerKey) +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/filter/bloom_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/filter/bloom_test.go new file mode 100644 index 0000000..1fb56f0 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/filter/bloom_test.go @@ -0,0 +1,142 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package filter + +import ( + "encoding/binary" + "github.com/syndtr/goleveldb/leveldb/util" + "testing" +) + +type harness struct { + t *testing.T + + bloom Filter + generator FilterGenerator + filter []byte +} + +func newHarness(t *testing.T) *harness { + bloom := NewBloomFilter(10) + return &harness{ + t: t, + bloom: bloom, + generator: bloom.NewGenerator(), + } +} + +func (h *harness) add(key []byte) { + h.generator.Add(key) +} + +func (h *harness) addNum(key uint32) { + var b [4]byte + binary.LittleEndian.PutUint32(b[:], key) + h.add(b[:]) +} + +func (h *harness) build() { + b := &util.Buffer{} + h.generator.Generate(b) + h.filter = b.Bytes() +} + +func (h *harness) reset() { + h.filter = nil +} + +func (h *harness) filterLen() int { + return len(h.filter) +} + +func (h *harness) assert(key []byte, want, silent bool) bool { + got := h.bloom.Contains(h.filter, key) + if !silent && got != want { + h.t.Errorf("assert on '%v' failed got '%v', want '%v'", key, got, want) + } + return got +} + +func (h *harness) assertNum(key uint32, want, silent bool) bool { + var b [4]byte + binary.LittleEndian.PutUint32(b[:], key) + return h.assert(b[:], want, silent) +} + +func TestBloomFilter_Empty(t *testing.T) { + h := newHarness(t) + h.build() + h.assert([]byte("hello"), false, false) + h.assert([]byte("world"), false, false) +} + +func TestBloomFilter_Small(t *testing.T) { + h := newHarness(t) + h.add([]byte("hello")) + h.add([]byte("world")) + h.build() + h.assert([]byte("hello"), true, false) + h.assert([]byte("world"), true, false) + h.assert([]byte("x"), false, false) + h.assert([]byte("foo"), false, false) +} + +func nextN(n int) int { + switch { + case n < 10: + n += 1 + case n < 100: + n += 10 + case n < 1000: + n += 100 + default: + n += 1000 + } + return n +} + +func TestBloomFilter_VaryingLengths(t *testing.T) { + h := newHarness(t) + var mediocre, good int + for n := 1; n < 10000; n = nextN(n) { + h.reset() + for i := 0; i < n; i++ { + h.addNum(uint32(i)) + } + h.build() + + got := h.filterLen() + want := (n * 10 / 8) + 40 + if got > want { + t.Errorf("filter len test failed, '%d' > '%d'", got, want) + } + + for i := 0; i < n; i++ { + h.assertNum(uint32(i), true, false) + } + + var rate float32 + for i := 0; i < 10000; i++ { + if h.assertNum(uint32(i+1000000000), true, true) { + rate++ + } + } + rate /= 10000 + if rate > 0.02 { + t.Errorf("false positive rate is more than 2%%, got %v, at len %d", rate, n) + } + if rate > 0.0125 { + mediocre++ + } else { + good++ + } + } + t.Logf("false positive rate: %d good, %d mediocre", good, mediocre) + if mediocre > good/5 { + t.Error("mediocre false positive rate is more than 
expected") + } +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/filter/filter.go b/vendor/github.com/syndtr/goleveldb/leveldb/filter/filter.go new file mode 100644 index 0000000..7a925c5 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/filter/filter.go @@ -0,0 +1,60 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Package filter provides interface and implementation of probabilistic +// data structure. +// +// The filter is resposible for creating small filter from a set of keys. +// These filter will then used to test whether a key is a member of the set. +// In many cases, a filter can cut down the number of disk seeks from a +// handful to a single disk seek per DB.Get call. +package filter + +// Buffer is the interface that wraps basic Alloc, Write and WriteByte methods. +type Buffer interface { + // Alloc allocs n bytes of slice from the buffer. This also advancing + // write offset. + Alloc(n int) []byte + + // Write appends the contents of p to the buffer. + Write(p []byte) (n int, err error) + + // WriteByte appends the byte c to the buffer. + WriteByte(c byte) error +} + +// Filter is the filter. +type Filter interface { + // Name returns the name of this policy. + // + // Note that if the filter encoding changes in an incompatible way, + // the name returned by this method must be changed. Otherwise, old + // incompatible filters may be passed to methods of this type. + Name() string + + // NewGenerator creates a new filter generator. + NewGenerator() FilterGenerator + + // Contains returns true if the filter contains the given key. + // + // The filter are filters generated by the filter generator. + Contains(filter, key []byte) bool +} + +// FilterGenerator is the filter generator. +type FilterGenerator interface { + // Add adds a key to the filter generator. + // + // The key may become invalid after call to this method end, therefor + // key must be copied if implementation require keeping key for later + // use. The key should not modified directly, doing so may cause + // undefined results. + Add(key []byte) + + // Generate generates filters based on keys passed so far. After call + // to Generate the filter generator maybe resetted, depends on implementation. + Generate(b Buffer) +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/iterator/array_iter.go b/vendor/github.com/syndtr/goleveldb/leveldb/iterator/array_iter.go new file mode 100644 index 0000000..a23ab05 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/iterator/array_iter.go @@ -0,0 +1,184 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package iterator + +import ( + "github.com/syndtr/goleveldb/leveldb/util" +) + +// BasicArray is the interface that wraps basic Len and Search method. +type BasicArray interface { + // Len returns length of the array. + Len() int + + // Search finds smallest index that point to a key that is greater + // than or equal to the given key. + Search(key []byte) int +} + +// Array is the interface that wraps BasicArray and basic Index method. +type Array interface { + BasicArray + + // Index returns key/value pair with index of i. + Index(i int) (key, value []byte) +} + +// Array is the interface that wraps BasicArray and basic Get method. 
+type ArrayIndexer interface { + BasicArray + + // Get returns a new data iterator with index of i. + Get(i int) Iterator +} + +type basicArrayIterator struct { + util.BasicReleaser + array BasicArray + pos int + err error +} + +func (i *basicArrayIterator) Valid() bool { + return i.pos >= 0 && i.pos < i.array.Len() && !i.Released() +} + +func (i *basicArrayIterator) First() bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + if i.array.Len() == 0 { + i.pos = -1 + return false + } + i.pos = 0 + return true +} + +func (i *basicArrayIterator) Last() bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + n := i.array.Len() + if n == 0 { + i.pos = 0 + return false + } + i.pos = n - 1 + return true +} + +func (i *basicArrayIterator) Seek(key []byte) bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + n := i.array.Len() + if n == 0 { + i.pos = 0 + return false + } + i.pos = i.array.Search(key) + if i.pos >= n { + return false + } + return true +} + +func (i *basicArrayIterator) Next() bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + i.pos++ + if n := i.array.Len(); i.pos >= n { + i.pos = n + return false + } + return true +} + +func (i *basicArrayIterator) Prev() bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + i.pos-- + if i.pos < 0 { + i.pos = -1 + return false + } + return true +} + +func (i *basicArrayIterator) Error() error { return i.err } + +type arrayIterator struct { + basicArrayIterator + array Array + pos int + key, value []byte +} + +func (i *arrayIterator) updateKV() { + if i.pos == i.basicArrayIterator.pos { + return + } + i.pos = i.basicArrayIterator.pos + if i.Valid() { + i.key, i.value = i.array.Index(i.pos) + } else { + i.key = nil + i.value = nil + } +} + +func (i *arrayIterator) Key() []byte { + i.updateKV() + return i.key +} + +func (i *arrayIterator) Value() []byte { + i.updateKV() + return i.value +} + +type arrayIteratorIndexer struct { + basicArrayIterator + array ArrayIndexer +} + +func (i *arrayIteratorIndexer) Get() Iterator { + if i.Valid() { + return i.array.Get(i.basicArrayIterator.pos) + } + return nil +} + +// NewArrayIterator returns an iterator from the given array. +func NewArrayIterator(array Array) Iterator { + return &arrayIterator{ + basicArrayIterator: basicArrayIterator{array: array, pos: -1}, + array: array, + pos: -1, + } +} + +// NewArrayIndexer returns an index iterator from the given array. +func NewArrayIndexer(array ArrayIndexer) IteratorIndexer { + return &arrayIteratorIndexer{ + basicArrayIterator: basicArrayIterator{array: array, pos: -1}, + array: array, + } +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/iterator/array_iter_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/iterator/array_iter_test.go new file mode 100644 index 0000000..f16d014 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/iterator/array_iter_test.go @@ -0,0 +1,30 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package iterator_test + +import ( + . "github.com/onsi/ginkgo" + + . "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/testutil" +) + +var _ = testutil.Defer(func() { + Describe("Array iterator", func() { + It("Should iterates and seeks correctly", func() { + // Build key/value. 
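The Array interface above is easy to satisfy for in-memory data. A hypothetical adapter over a sorted string slice (sortedStrings is not part of this patch; its Search satisfies the "smallest index whose key is greater than or equal to the given key" contract via sort.SearchStrings):

package main

import (
	"fmt"
	"sort"

	"github.com/syndtr/goleveldb/leveldb/iterator"
)

// sortedStrings adapts a sorted []string to iterator.Array.
type sortedStrings []string

func (s sortedStrings) Len() int { return len(s) }

// Search returns the smallest index whose key is >= the given key.
func (s sortedStrings) Search(key []byte) int {
	return sort.SearchStrings([]string(s), string(key))
}

// Index returns the i'th pair; for illustration the value is the key itself.
func (s sortedStrings) Index(i int) (key, value []byte) {
	return []byte(s[i]), []byte(s[i])
}

func main() {
	it := iterator.NewArrayIterator(sortedStrings{"apple", "mango", "pear"})
	defer it.Release()
	// Seek positions the iterator at the first key >= "b".
	for ok := it.Seek([]byte("b")); ok; ok = it.Next() {
		fmt.Printf("%s\n", it.Key()) // mango, pear
	}
}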
+ kv := testutil.KeyValue_Generate(nil, 70, 1, 1, 5, 3, 3) + + // Test the iterator. + t := testutil.IteratorTesting{ + KeyValue: kv.Clone(), + Iter: NewArrayIterator(kv), + } + testutil.DoIteratorTesting(&t) + }) + }) +}) diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/iterator/indexed_iter.go b/vendor/github.com/syndtr/goleveldb/leveldb/iterator/indexed_iter.go new file mode 100644 index 0000000..939adbb --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/iterator/indexed_iter.go @@ -0,0 +1,242 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package iterator + +import ( + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/util" +) + +// IteratorIndexer is the interface that wraps CommonIterator and basic Get +// method. IteratorIndexer provides index for indexed iterator. +type IteratorIndexer interface { + CommonIterator + + // Get returns a new data iterator for the current position, or nil if + // done. + Get() Iterator +} + +type indexedIterator struct { + util.BasicReleaser + index IteratorIndexer + strict bool + + data Iterator + err error + errf func(err error) + closed bool +} + +func (i *indexedIterator) setData() { + if i.data != nil { + i.data.Release() + } + i.data = i.index.Get() +} + +func (i *indexedIterator) clearData() { + if i.data != nil { + i.data.Release() + } + i.data = nil +} + +func (i *indexedIterator) indexErr() { + if err := i.index.Error(); err != nil { + if i.errf != nil { + i.errf(err) + } + i.err = err + } +} + +func (i *indexedIterator) dataErr() bool { + if err := i.data.Error(); err != nil { + if i.errf != nil { + i.errf(err) + } + if i.strict || !errors.IsCorrupted(err) { + i.err = err + return true + } + } + return false +} + +func (i *indexedIterator) Valid() bool { + return i.data != nil && i.data.Valid() +} + +func (i *indexedIterator) First() bool { + if i.err != nil { + return false + } else if i.Released() { + i.err = ErrIterReleased + return false + } + + if !i.index.First() { + i.indexErr() + i.clearData() + return false + } + i.setData() + return i.Next() +} + +func (i *indexedIterator) Last() bool { + if i.err != nil { + return false + } else if i.Released() { + i.err = ErrIterReleased + return false + } + + if !i.index.Last() { + i.indexErr() + i.clearData() + return false + } + i.setData() + if !i.data.Last() { + if i.dataErr() { + return false + } + i.clearData() + return i.Prev() + } + return true +} + +func (i *indexedIterator) Seek(key []byte) bool { + if i.err != nil { + return false + } else if i.Released() { + i.err = ErrIterReleased + return false + } + + if !i.index.Seek(key) { + i.indexErr() + i.clearData() + return false + } + i.setData() + if !i.data.Seek(key) { + if i.dataErr() { + return false + } + i.clearData() + return i.Next() + } + return true +} + +func (i *indexedIterator) Next() bool { + if i.err != nil { + return false + } else if i.Released() { + i.err = ErrIterReleased + return false + } + + switch { + case i.data != nil && !i.data.Next(): + if i.dataErr() { + return false + } + i.clearData() + fallthrough + case i.data == nil: + if !i.index.Next() { + i.indexErr() + return false + } + i.setData() + return i.Next() + } + return true +} + +func (i *indexedIterator) Prev() bool { + if i.err != nil { + return false + } else if i.Released() { + i.err = ErrIterReleased + return false + } + + switch { + case i.data != nil && 
!i.data.Prev(): + if i.dataErr() { + return false + } + i.clearData() + fallthrough + case i.data == nil: + if !i.index.Prev() { + i.indexErr() + return false + } + i.setData() + if !i.data.Last() { + if i.dataErr() { + return false + } + i.clearData() + return i.Prev() + } + } + return true +} + +func (i *indexedIterator) Key() []byte { + if i.data == nil { + return nil + } + return i.data.Key() +} + +func (i *indexedIterator) Value() []byte { + if i.data == nil { + return nil + } + return i.data.Value() +} + +func (i *indexedIterator) Release() { + i.clearData() + i.index.Release() + i.BasicReleaser.Release() +} + +func (i *indexedIterator) Error() error { + if i.err != nil { + return i.err + } + if err := i.index.Error(); err != nil { + return err + } + return nil +} + +func (i *indexedIterator) SetErrorCallback(f func(err error)) { + i.errf = f +} + +// NewIndexedIterator returns an 'indexed iterator'. An index is iterator +// that returns another iterator, a 'data iterator'. A 'data iterator' is the +// iterator that contains actual key/value pairs. +// +// If strict is true the any 'corruption errors' (i.e errors.IsCorrupted(err) == true) +// won't be ignored and will halt 'indexed iterator', otherwise the iterator will +// continue to the next 'data iterator'. Corruption on 'index iterator' will not be +// ignored and will halt the iterator. +func NewIndexedIterator(index IteratorIndexer, strict bool) Iterator { + return &indexedIterator{index: index, strict: strict} +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/iterator/indexed_iter_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/iterator/indexed_iter_test.go new file mode 100644 index 0000000..fde8016 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/iterator/indexed_iter_test.go @@ -0,0 +1,83 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package iterator_test + +import ( + "sort" + + . "github.com/onsi/ginkgo" + + "github.com/syndtr/goleveldb/leveldb/comparer" + . "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/testutil" +) + +type keyValue struct { + key []byte + testutil.KeyValue +} + +type keyValueIndex []keyValue + +func (x keyValueIndex) Search(key []byte) int { + return sort.Search(x.Len(), func(i int) bool { + return comparer.DefaultComparer.Compare(x[i].key, key) >= 0 + }) +} + +func (x keyValueIndex) Len() int { return len(x) } +func (x keyValueIndex) Index(i int) (key, value []byte) { return x[i].key, nil } +func (x keyValueIndex) Get(i int) Iterator { return NewArrayIterator(x[i]) } + +var _ = testutil.Defer(func() { + Describe("Indexed iterator", func() { + Test := func(n ...int) func() { + if len(n) == 0 { + rnd := testutil.NewRand() + n = make([]int, rnd.Intn(17)+3) + for i := range n { + n[i] = rnd.Intn(19) + 1 + } + } + + return func() { + It("Should iterates and seeks correctly", func(done Done) { + // Build key/value. + index := make(keyValueIndex, len(n)) + sum := 0 + for _, x := range n { + sum += x + } + kv := testutil.KeyValue_Generate(nil, sum, 1, 1, 10, 4, 4) + for i, j := 0, 0; i < len(n); i++ { + for x := n[i]; x > 0; x-- { + key, value := kv.Index(j) + index[i].key = key + index[i].Put(key, value) + j++ + } + } + + // Test the iterator. 
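Conceptually, the indexed iterator above flattens two levels of iteration: each position of the index yields a data iterator via Get, and those data iterators are drained in order. A simplified model of the forward pass (error handling and the strict/corruption policy omitted; walkIndexed and emit are illustrative, not part of the patch):

// walkIndexed is a simplified model of indexedIterator's forward direction;
// emit stands in for the caller's Key/Value access.
func walkIndexed(index IteratorIndexer, emit func(k, v []byte)) {
	for index.Next() {
		data := index.Get()
		for data.Next() {
			emit(data.Key(), data.Value())
		}
		data.Release()
	}
}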
+ t := testutil.IteratorTesting{ + KeyValue: kv.Clone(), + Iter: NewIndexedIterator(NewArrayIndexer(index), true), + } + testutil.DoIteratorTesting(&t) + done <- true + }, 15.0) + } + } + + Describe("with 100 keys", Test(100)) + Describe("with 50-50 keys", Test(50, 50)) + Describe("with 50-1 keys", Test(50, 1)) + Describe("with 50-1-50 keys", Test(50, 1, 50)) + Describe("with 1-50 keys", Test(1, 50)) + Describe("with random N-keys", Test()) + }) +}) diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/iterator/iter.go b/vendor/github.com/syndtr/goleveldb/leveldb/iterator/iter.go new file mode 100644 index 0000000..96fb0f6 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/iterator/iter.go @@ -0,0 +1,132 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Package iterator provides interface and implementation to traverse over +// contents of a database. +package iterator + +import ( + "errors" + + "github.com/syndtr/goleveldb/leveldb/util" +) + +var ( + ErrIterReleased = errors.New("leveldb/iterator: iterator released") +) + +// IteratorSeeker is the interface that wraps the 'seeks method'. +type IteratorSeeker interface { + // First moves the iterator to the first key/value pair. If the iterator + // only contains one key/value pair then First and Last would moves + // to the same key/value pair. + // It returns whether such pair exist. + First() bool + + // Last moves the iterator to the last key/value pair. If the iterator + // only contains one key/value pair then First and Last would moves + // to the same key/value pair. + // It returns whether such pair exist. + Last() bool + + // Seek moves the iterator to the first key/value pair whose key is greater + // than or equal to the given key. + // It returns whether such pair exist. + // + // It is safe to modify the contents of the argument after Seek returns. + Seek(key []byte) bool + + // Next moves the iterator to the next key/value pair. + // It returns false if the iterator is exhausted. + Next() bool + + // Prev moves the iterator to the previous key/value pair. + // It returns false if the iterator is exhausted. + Prev() bool +} + +// CommonIterator is the interface that wraps common iterator methods. +type CommonIterator interface { + IteratorSeeker + + // util.Releaser is the interface that wraps basic Release method. + // When called Release will releases any resources associated with the + // iterator. + util.Releaser + + // util.ReleaseSetter is the interface that wraps the basic SetReleaser + // method. + util.ReleaseSetter + + // TODO: Remove this when ready. + Valid() bool + + // Error returns any accumulated error. Exhausting all the key/value pairs + // is not considered to be an error. + Error() error +} + +// Iterator iterates over a DB's key/value pairs in key order. +// +// When encounter an error any 'seeks method' will return false and will +// yield no key/value pairs. The error can be queried by calling the Error +// method. Calling Release is still necessary. +// +// An iterator must be released after use, but it is not necessary to read +// an iterator until exhaustion. +// Also, an iterator is not necessarily safe for concurrent use, but it is +// safe to use multiple iterators concurrently, with each in a dedicated +// goroutine. +type Iterator interface { + CommonIterator + + // Key returns the key of the current key/value pair, or nil if done. 
+ // The caller should not modify the contents of the returned slice, and + // its contents may change on the next call to any 'seeks method'. + Key() []byte + + // Value returns the value of the current key/value pair, or nil if done. + // The caller should not modify the contents of the returned slice, and + // its contents may change on the next call to any 'seeks method'. + Value() []byte +} + +// ErrorCallbackSetter is the interface that wraps basic SetErrorCallback +// method. +// +// ErrorCallbackSetter implemented by indexed and merged iterator. +type ErrorCallbackSetter interface { + // SetErrorCallback allows set an error callback of the corresponding + // iterator. Use nil to clear the callback. + SetErrorCallback(f func(err error)) +} + +type emptyIterator struct { + util.BasicReleaser + err error +} + +func (i *emptyIterator) rErr() { + if i.err == nil && i.Released() { + i.err = ErrIterReleased + } +} + +func (*emptyIterator) Valid() bool { return false } +func (i *emptyIterator) First() bool { i.rErr(); return false } +func (i *emptyIterator) Last() bool { i.rErr(); return false } +func (i *emptyIterator) Seek(key []byte) bool { i.rErr(); return false } +func (i *emptyIterator) Next() bool { i.rErr(); return false } +func (i *emptyIterator) Prev() bool { i.rErr(); return false } +func (*emptyIterator) Key() []byte { return nil } +func (*emptyIterator) Value() []byte { return nil } +func (i *emptyIterator) Error() error { return i.err } + +// NewEmptyIterator creates an empty iterator. The err parameter can be +// nil, but if not nil the given err will be returned by Error method. +func NewEmptyIterator(err error) Iterator { + return &emptyIterator{err: err} +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/iterator/iter_suite_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/iterator/iter_suite_test.go new file mode 100644 index 0000000..5ef8d5b --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/iterator/iter_suite_test.go @@ -0,0 +1,11 @@ +package iterator_test + +import ( + "testing" + + "github.com/syndtr/goleveldb/leveldb/testutil" +) + +func TestIterator(t *testing.T) { + testutil.RunSuite(t, "Iterator Suite") +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/iterator/merged_iter.go b/vendor/github.com/syndtr/goleveldb/leveldb/iterator/merged_iter.go new file mode 100644 index 0000000..1a7e29d --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/iterator/merged_iter.go @@ -0,0 +1,304 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
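Every concrete iterator in this package is consumed through the same Iterator contract defined above. A small helper sketch, written as if inside package iterator (drainKeys is not part of the patch), showing the canonical drain loop:

// drainKeys collects all keys from any iterator in key order, releasing the
// iterator when done and surfacing any accumulated error.
func drainKeys(it Iterator) ([]string, error) {
	defer it.Release()
	var keys []string
	for ok := it.First(); ok; ok = it.Next() {
		// Key's backing slice may be reused by the iterator on the next
		// seek; converting to string copies it.
		keys = append(keys, string(it.Key()))
	}
	return keys, it.Error()
}

// drainKeys(NewEmptyIterator(nil)) returns (nil, nil).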
+ +package iterator + +import ( + "github.com/syndtr/goleveldb/leveldb/comparer" + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/util" +) + +type dir int + +const ( + dirReleased dir = iota - 1 + dirSOI + dirEOI + dirBackward + dirForward +) + +type mergedIterator struct { + cmp comparer.Comparer + iters []Iterator + strict bool + + keys [][]byte + index int + dir dir + err error + errf func(err error) + releaser util.Releaser +} + +func assertKey(key []byte) []byte { + if key == nil { + panic("leveldb/iterator: nil key") + } + return key +} + +func (i *mergedIterator) iterErr(iter Iterator) bool { + if err := iter.Error(); err != nil { + if i.errf != nil { + i.errf(err) + } + if i.strict || !errors.IsCorrupted(err) { + i.err = err + return true + } + } + return false +} + +func (i *mergedIterator) Valid() bool { + return i.err == nil && i.dir > dirEOI +} + +func (i *mergedIterator) First() bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + for x, iter := range i.iters { + switch { + case iter.First(): + i.keys[x] = assertKey(iter.Key()) + case i.iterErr(iter): + return false + default: + i.keys[x] = nil + } + } + i.dir = dirSOI + return i.next() +} + +func (i *mergedIterator) Last() bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + for x, iter := range i.iters { + switch { + case iter.Last(): + i.keys[x] = assertKey(iter.Key()) + case i.iterErr(iter): + return false + default: + i.keys[x] = nil + } + } + i.dir = dirEOI + return i.prev() +} + +func (i *mergedIterator) Seek(key []byte) bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + for x, iter := range i.iters { + switch { + case iter.Seek(key): + i.keys[x] = assertKey(iter.Key()) + case i.iterErr(iter): + return false + default: + i.keys[x] = nil + } + } + i.dir = dirSOI + return i.next() +} + +func (i *mergedIterator) next() bool { + var key []byte + if i.dir == dirForward { + key = i.keys[i.index] + } + for x, tkey := range i.keys { + if tkey != nil && (key == nil || i.cmp.Compare(tkey, key) < 0) { + key = tkey + i.index = x + } + } + if key == nil { + i.dir = dirEOI + return false + } + i.dir = dirForward + return true +} + +func (i *mergedIterator) Next() bool { + if i.dir == dirEOI || i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + switch i.dir { + case dirSOI: + return i.First() + case dirBackward: + key := append([]byte{}, i.keys[i.index]...) 
+ if !i.Seek(key) { + return false + } + return i.Next() + } + + x := i.index + iter := i.iters[x] + switch { + case iter.Next(): + i.keys[x] = assertKey(iter.Key()) + case i.iterErr(iter): + return false + default: + i.keys[x] = nil + } + return i.next() +} + +func (i *mergedIterator) prev() bool { + var key []byte + if i.dir == dirBackward { + key = i.keys[i.index] + } + for x, tkey := range i.keys { + if tkey != nil && (key == nil || i.cmp.Compare(tkey, key) > 0) { + key = tkey + i.index = x + } + } + if key == nil { + i.dir = dirSOI + return false + } + i.dir = dirBackward + return true +} + +func (i *mergedIterator) Prev() bool { + if i.dir == dirSOI || i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + switch i.dir { + case dirEOI: + return i.Last() + case dirForward: + key := append([]byte{}, i.keys[i.index]...) + for x, iter := range i.iters { + if x == i.index { + continue + } + seek := iter.Seek(key) + switch { + case seek && iter.Prev(), !seek && iter.Last(): + i.keys[x] = assertKey(iter.Key()) + case i.iterErr(iter): + return false + default: + i.keys[x] = nil + } + } + } + + x := i.index + iter := i.iters[x] + switch { + case iter.Prev(): + i.keys[x] = assertKey(iter.Key()) + case i.iterErr(iter): + return false + default: + i.keys[x] = nil + } + return i.prev() +} + +func (i *mergedIterator) Key() []byte { + if i.err != nil || i.dir <= dirEOI { + return nil + } + return i.keys[i.index] +} + +func (i *mergedIterator) Value() []byte { + if i.err != nil || i.dir <= dirEOI { + return nil + } + return i.iters[i.index].Value() +} + +func (i *mergedIterator) Release() { + if i.dir != dirReleased { + i.dir = dirReleased + for _, iter := range i.iters { + iter.Release() + } + i.iters = nil + i.keys = nil + if i.releaser != nil { + i.releaser.Release() + i.releaser = nil + } + } +} + +func (i *mergedIterator) SetReleaser(releaser util.Releaser) { + if i.dir == dirReleased { + panic(util.ErrReleased) + } + if i.releaser != nil && releaser != nil { + panic(util.ErrHasReleaser) + } + i.releaser = releaser +} + +func (i *mergedIterator) Error() error { + return i.err +} + +func (i *mergedIterator) SetErrorCallback(f func(err error)) { + i.errf = f +} + +// NewMergedIterator returns an iterator that merges its input. Walking the +// resultant iterator will return all key/value pairs of all input iterators +// in strictly increasing key order, as defined by cmp. +// The input's key ranges may overlap, but there are assumed to be no duplicate +// keys: if iters[i] contains a key k then iters[j] will not contain that key k. +// None of the iters may be nil. +// +// If strict is true the any 'corruption errors' (i.e errors.IsCorrupted(err) == true) +// won't be ignored and will halt 'merged iterator', otherwise the iterator will +// continue to the next 'input iterator'. +func NewMergedIterator(iters []Iterator, cmp comparer.Comparer, strict bool) Iterator { + return &mergedIterator{ + iters: iters, + cmp: cmp, + strict: strict, + keys: make([][]byte, len(iters)), + } +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/iterator/merged_iter_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/iterator/merged_iter_test.go new file mode 100644 index 0000000..ee40881 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/iterator/merged_iter_test.go @@ -0,0 +1,60 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. 
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package iterator_test
+
+import (
+	. "github.com/onsi/ginkgo"
+	. "github.com/onsi/gomega"
+
+	"github.com/syndtr/goleveldb/leveldb/comparer"
+	. "github.com/syndtr/goleveldb/leveldb/iterator"
+	"github.com/syndtr/goleveldb/leveldb/testutil"
+)
+
+var _ = testutil.Defer(func() {
+	Describe("Merged iterator", func() {
+		Test := func(filled int, empty int) func() {
+			return func() {
+				It("Should iterates and seeks correctly", func(done Done) {
+					rnd := testutil.NewRand()
+
+					// Build key/value.
+					filledKV := make([]testutil.KeyValue, filled)
+					kv := testutil.KeyValue_Generate(nil, 100, 1, 1, 10, 4, 4)
+					kv.Iterate(func(i int, key, value []byte) {
+						filledKV[rnd.Intn(filled)].Put(key, value)
+					})
+
+					// Create iterators.
+					iters := make([]Iterator, filled+empty)
+					for i := range iters {
+						if empty == 0 || (rnd.Int()%2 == 0 && filled > 0) {
+							filled--
+							Expect(filledKV[filled].Len()).ShouldNot(BeZero())
+							iters[i] = NewArrayIterator(filledKV[filled])
+						} else {
+							empty--
+							iters[i] = NewEmptyIterator(nil)
+						}
+					}
+
+					// Test the iterator.
+					t := testutil.IteratorTesting{
+						KeyValue: kv.Clone(),
+						Iter:     NewMergedIterator(iters, comparer.DefaultComparer, true),
+					}
+					testutil.DoIteratorTesting(&t)
+					done <- true
+				}, 15.0)
+			}
+		}
+
+		Describe("with three, all filled iterators", Test(3, 0))
+		Describe("with one filled, one empty iterators", Test(1, 1))
+		Describe("with one filled, two empty iterators", Test(1, 2))
+	})
+})
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/journal/journal.go b/vendor/github.com/syndtr/goleveldb/leveldb/journal/journal.go
new file mode 100644
index 0000000..d094c3d
--- /dev/null
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/journal/journal.go
@@ -0,0 +1,524 @@
+// Copyright 2011 The LevelDB-Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Taken from: https://code.google.com/p/leveldb-go/source/browse/leveldb/record/record.go?r=1d5ccbe03246da926391ee12d1c6caae054ff4b0
+// License, author and contributor information can be found at the URLs below, respectively:
+// https://code.google.com/p/leveldb-go/source/browse/LICENSE
+// https://code.google.com/p/leveldb-go/source/browse/AUTHORS
+// https://code.google.com/p/leveldb-go/source/browse/CONTRIBUTORS
+
+// Package journal reads and writes sequences of journals. Each journal is a stream
+// of bytes that completes before the next journal starts.
+//
+// When reading, call Next to obtain an io.Reader for the next journal. Next will
+// return io.EOF when there are no more journals. It is valid to call Next
+// without reading the current journal to exhaustion.
+//
+// When writing, call Next to obtain an io.Writer for the next journal. Calling
+// Next finishes the current journal. Call Close to finish the final journal.
+//
+// Optionally, call Flush to finish the current journal and flush the underlying
+// writer without starting a new journal. To start a new journal after flushing,
+// call Next.
+//
+// Neither Readers nor Writers are safe to use concurrently.
+//
+// Example code:
+//	func read(r io.Reader) ([]string, error) {
+//		var ss []string
+//		journals := journal.NewReader(r, nil, true, true)
+//		for {
+//			j, err := journals.Next()
+//			if err == io.EOF {
+//				break
+//			}
+//			if err != nil {
+//				return nil, err
+//			}
+//			s, err := ioutil.ReadAll(j)
+//			if err != nil {
+//				return nil, err
+//			}
+//			ss = append(ss, string(s))
+//		}
+//		return ss, nil
+//	}
+//
+//	func write(w io.Writer, ss []string) error {
+//		journals := journal.NewWriter(w)
+//		for _, s := range ss {
+//			j, err := journals.Next()
+//			if err != nil {
+//				return err
+//			}
+//			if _, err := j.Write([]byte(s)); err != nil {
+//				return err
+//			}
+//		}
+//		return journals.Close()
+//	}
+//
+// The wire format is that the stream is divided into 32KiB blocks, and each
+// block contains a number of tightly packed chunks. Chunks cannot cross block
+// boundaries. The last block may be shorter than 32 KiB. Any unused bytes in a
+// block must be zero.
+//
+// A journal maps to one or more chunks. Each chunk has a 7 byte header (a 4
+// byte checksum, a 2 byte little-endian uint16 length, and a 1 byte chunk type)
+// followed by a payload. The checksum is over the chunk type and the payload.
+//
+// There are four chunk types: whether the chunk is the full journal, or the
+// first, middle or last chunk of a multi-chunk journal. A multi-chunk journal
+// has one first chunk, zero or more middle chunks, and one last chunk.
+//
+// The wire format allows for limited recovery in the face of data corruption:
+// on a format error (such as a checksum mismatch), the reader moves to the
+// next block and looks for the next full or first chunk.
+package journal
+
+import (
+	"encoding/binary"
+	"fmt"
+	"io"
+
+	"github.com/syndtr/goleveldb/leveldb/errors"
+	"github.com/syndtr/goleveldb/leveldb/storage"
+	"github.com/syndtr/goleveldb/leveldb/util"
+)
+
+// These constants are part of the wire format and should not be changed.
+const (
+	fullChunkType   = 1
+	firstChunkType  = 2
+	middleChunkType = 3
+	lastChunkType   = 4
+)
+
+const (
+	blockSize  = 32 * 1024
+	headerSize = 7
+)
+
+type flusher interface {
+	Flush() error
+}
+
+// ErrCorrupted is the error type that is generated by a corrupted block or chunk.
+type ErrCorrupted struct {
+	Size   int
+	Reason string
+}
+
+func (e *ErrCorrupted) Error() string {
+	return fmt.Sprintf("leveldb/journal: block/chunk corrupted: %s (%d bytes)", e.Reason, e.Size)
+}
+
+// Dropper is the interface that wraps the simple Drop method. The Drop
+// method will be called when the journal reader drops a block or chunk.
+type Dropper interface {
+	Drop(err error)
+}
+
+// Reader reads journals from an underlying io.Reader.
+type Reader struct {
+	// r is the underlying reader.
+	r io.Reader
+	// the dropper.
+	dropper Dropper
+	// strict flag.
+	strict bool
+	// checksum flag.
+	checksum bool
+	// seq is the sequence number of the current journal.
+	seq int
+	// buf[i:j] is the unread portion of the current chunk's payload.
+	// The low bound, i, excludes the chunk header.
+	i, j int
+	// n is the number of bytes of buf that are valid. Once reading has started,
+	// only the final block can have n < blockSize.
+	n int
+	// last is whether the current chunk is the last chunk of the journal.
+	last bool
+	// err is any accumulated error.
+	err error
+	// buf is the buffer.
+	buf [blockSize]byte
+}
+
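The 7-byte chunk header described above can be unpacked mechanically. A sketch written as if inside package journal, mirroring the offsets used by nextChunk below (decodeChunkHeader is illustrative, not part of the patch):

// decodeChunkHeader splits a journal chunk header into its three fields:
// a little-endian CRC over type+payload, a little-endian payload length,
// and the chunk type (full, first, middle or last).
func decodeChunkHeader(h [headerSize]byte) (checksum uint32, length uint16, chunkType byte) {
	checksum = binary.LittleEndian.Uint32(h[0:4])
	length = binary.LittleEndian.Uint16(h[4:6])
	chunkType = h[6]
	return
}

+// NewReader returns a new reader. The dropper may be nil, and if
+// strict is true then a corrupted or invalid chunk will halt the journal
+// reader entirely.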
+func NewReader(r io.Reader, dropper Dropper, strict, checksum bool) *Reader { + return &Reader{ + r: r, + dropper: dropper, + strict: strict, + checksum: checksum, + last: true, + } +} + +var errSkip = errors.New("leveldb/journal: skipped") + +func (r *Reader) corrupt(n int, reason string, skip bool) error { + if r.dropper != nil { + r.dropper.Drop(&ErrCorrupted{n, reason}) + } + if r.strict && !skip { + r.err = errors.NewErrCorrupted(storage.FileDesc{}, &ErrCorrupted{n, reason}) + return r.err + } + return errSkip +} + +// nextChunk sets r.buf[r.i:r.j] to hold the next chunk's payload, reading the +// next block into the buffer if necessary. +func (r *Reader) nextChunk(first bool) error { + for { + if r.j+headerSize <= r.n { + checksum := binary.LittleEndian.Uint32(r.buf[r.j+0 : r.j+4]) + length := binary.LittleEndian.Uint16(r.buf[r.j+4 : r.j+6]) + chunkType := r.buf[r.j+6] + unprocBlock := r.n - r.j + if checksum == 0 && length == 0 && chunkType == 0 { + // Drop entire block. + r.i = r.n + r.j = r.n + return r.corrupt(unprocBlock, "zero header", false) + } + if chunkType < fullChunkType || chunkType > lastChunkType { + // Drop entire block. + r.i = r.n + r.j = r.n + return r.corrupt(unprocBlock, fmt.Sprintf("invalid chunk type %#x", chunkType), false) + } + r.i = r.j + headerSize + r.j = r.j + headerSize + int(length) + if r.j > r.n { + // Drop entire block. + r.i = r.n + r.j = r.n + return r.corrupt(unprocBlock, "chunk length overflows block", false) + } else if r.checksum && checksum != util.NewCRC(r.buf[r.i-1:r.j]).Value() { + // Drop entire block. + r.i = r.n + r.j = r.n + return r.corrupt(unprocBlock, "checksum mismatch", false) + } + if first && chunkType != fullChunkType && chunkType != firstChunkType { + chunkLength := (r.j - r.i) + headerSize + r.i = r.j + // Report the error, but skip it. + return r.corrupt(chunkLength, "orphan chunk", true) + } + r.last = chunkType == fullChunkType || chunkType == lastChunkType + return nil + } + + // The last block. + if r.n < blockSize && r.n > 0 { + if !first { + return r.corrupt(0, "missing chunk part", false) + } + r.err = io.EOF + return r.err + } + + // Read block. + n, err := io.ReadFull(r.r, r.buf[:]) + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { + return err + } + if n == 0 { + if !first { + return r.corrupt(0, "missing chunk part", false) + } + r.err = io.EOF + return r.err + } + r.i, r.j, r.n = 0, 0, n + } +} + +// Next returns a reader for the next journal. It returns io.EOF if there are no +// more journals. The reader returned becomes stale after the next Next call, +// and should no longer be used. If strict is false, the reader will returns +// io.ErrUnexpectedEOF error when found corrupted journal. +func (r *Reader) Next() (io.Reader, error) { + r.seq++ + if r.err != nil { + return nil, r.err + } + r.i = r.j + for { + if err := r.nextChunk(true); err == nil { + break + } else if err != errSkip { + return nil, err + } + } + return &singleReader{r, r.seq, nil}, nil +} + +// Reset resets the journal reader, allows reuse of the journal reader. Reset returns +// last accumulated error. 
+func (r *Reader) Reset(reader io.Reader, dropper Dropper, strict, checksum bool) error { + r.seq++ + err := r.err + r.r = reader + r.dropper = dropper + r.strict = strict + r.checksum = checksum + r.i = 0 + r.j = 0 + r.n = 0 + r.last = true + r.err = nil + return err +} + +type singleReader struct { + r *Reader + seq int + err error +} + +func (x *singleReader) Read(p []byte) (int, error) { + r := x.r + if r.seq != x.seq { + return 0, errors.New("leveldb/journal: stale reader") + } + if x.err != nil { + return 0, x.err + } + if r.err != nil { + return 0, r.err + } + for r.i == r.j { + if r.last { + return 0, io.EOF + } + x.err = r.nextChunk(false) + if x.err != nil { + if x.err == errSkip { + x.err = io.ErrUnexpectedEOF + } + return 0, x.err + } + } + n := copy(p, r.buf[r.i:r.j]) + r.i += n + return n, nil +} + +func (x *singleReader) ReadByte() (byte, error) { + r := x.r + if r.seq != x.seq { + return 0, errors.New("leveldb/journal: stale reader") + } + if x.err != nil { + return 0, x.err + } + if r.err != nil { + return 0, r.err + } + for r.i == r.j { + if r.last { + return 0, io.EOF + } + x.err = r.nextChunk(false) + if x.err != nil { + if x.err == errSkip { + x.err = io.ErrUnexpectedEOF + } + return 0, x.err + } + } + c := r.buf[r.i] + r.i++ + return c, nil +} + +// Writer writes journals to an underlying io.Writer. +type Writer struct { + // w is the underlying writer. + w io.Writer + // seq is the sequence number of the current journal. + seq int + // f is w as a flusher. + f flusher + // buf[i:j] is the bytes that will become the current chunk. + // The low bound, i, includes the chunk header. + i, j int + // buf[:written] has already been written to w. + // written is zero unless Flush has been called. + written int + // first is whether the current chunk is the first chunk of the journal. + first bool + // pending is whether a chunk is buffered but not yet written. + pending bool + // err is any accumulated error. + err error + // buf is the buffer. + buf [blockSize]byte +} + +// NewWriter returns a new Writer. +func NewWriter(w io.Writer) *Writer { + f, _ := w.(flusher) + return &Writer{ + w: w, + f: f, + } +} + +// fillHeader fills in the header for the pending chunk. +func (w *Writer) fillHeader(last bool) { + if w.i+headerSize > w.j || w.j > blockSize { + panic("leveldb/journal: bad writer state") + } + if last { + if w.first { + w.buf[w.i+6] = fullChunkType + } else { + w.buf[w.i+6] = lastChunkType + } + } else { + if w.first { + w.buf[w.i+6] = firstChunkType + } else { + w.buf[w.i+6] = middleChunkType + } + } + binary.LittleEndian.PutUint32(w.buf[w.i+0:w.i+4], util.NewCRC(w.buf[w.i+6:w.j]).Value()) + binary.LittleEndian.PutUint16(w.buf[w.i+4:w.i+6], uint16(w.j-w.i-headerSize)) +} + +// writeBlock writes the buffered block to the underlying writer, and reserves +// space for the next chunk's header. +func (w *Writer) writeBlock() { + _, w.err = w.w.Write(w.buf[w.written:]) + w.i = 0 + w.j = headerSize + w.written = 0 +} + +// writePending finishes the current journal and writes the buffer to the +// underlying writer. +func (w *Writer) writePending() { + if w.err != nil { + return + } + if w.pending { + w.fillHeader(true) + w.pending = false + } + _, w.err = w.w.Write(w.buf[w.written:w.j]) + w.written = w.j +} + +// Close finishes the current journal and closes the writer. 
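fillHeader above picks among the four chunk types from two bits of state: whether the chunk starts its journal and whether it ends it. The decision table as a standalone sketch, written as if inside package journal (chunkTypeFor is illustrative, not part of the patch):

// chunkTypeFor mirrors fillHeader's logic: a chunk that both starts and ends
// its journal is "full"; otherwise it is first, last or middle.
func chunkTypeFor(first, last bool) byte {
	switch {
	case first && last:
		return fullChunkType
	case first:
		return firstChunkType
	case last:
		return lastChunkType
	default:
		return middleChunkType
	}
}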
+func (w *Writer) Close() error { + w.seq++ + w.writePending() + if w.err != nil { + return w.err + } + w.err = errors.New("leveldb/journal: closed Writer") + return nil +} + +// Flush finishes the current journal, writes to the underlying writer, and +// flushes it if that writer implements interface{ Flush() error }. +func (w *Writer) Flush() error { + w.seq++ + w.writePending() + if w.err != nil { + return w.err + } + if w.f != nil { + w.err = w.f.Flush() + return w.err + } + return nil +} + +// Reset resets the journal writer, allows reuse of the journal writer. Reset +// will also closes the journal writer if not already. +func (w *Writer) Reset(writer io.Writer) (err error) { + w.seq++ + if w.err == nil { + w.writePending() + err = w.err + } + w.w = writer + w.f, _ = writer.(flusher) + w.i = 0 + w.j = 0 + w.written = 0 + w.first = false + w.pending = false + w.err = nil + return +} + +// Next returns a writer for the next journal. The writer returned becomes stale +// after the next Close, Flush or Next call, and should no longer be used. +func (w *Writer) Next() (io.Writer, error) { + w.seq++ + if w.err != nil { + return nil, w.err + } + if w.pending { + w.fillHeader(true) + } + w.i = w.j + w.j = w.j + headerSize + // Check if there is room in the block for the header. + if w.j > blockSize { + // Fill in the rest of the block with zeroes. + for k := w.i; k < blockSize; k++ { + w.buf[k] = 0 + } + w.writeBlock() + if w.err != nil { + return nil, w.err + } + } + w.first = true + w.pending = true + return singleWriter{w, w.seq}, nil +} + +type singleWriter struct { + w *Writer + seq int +} + +func (x singleWriter) Write(p []byte) (int, error) { + w := x.w + if w.seq != x.seq { + return 0, errors.New("leveldb/journal: stale writer") + } + if w.err != nil { + return 0, w.err + } + n0 := len(p) + for len(p) > 0 { + // Write a block, if it is full. + if w.j == blockSize { + w.fillHeader(false) + w.writeBlock() + if w.err != nil { + return 0, w.err + } + w.first = false + } + // Copy bytes into the buffer. + n := copy(w.buf[w.j:], p) + w.j += n + p = p[n:] + } + return n0, nil +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/journal/journal_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/journal/journal_test.go new file mode 100644 index 0000000..0fcf225 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/journal/journal_test.go @@ -0,0 +1,818 @@ +// Copyright 2011 The LevelDB-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Taken from: https://code.google.com/p/leveldb-go/source/browse/leveldb/record/record_test.go?r=df1fa28f7f3be6c3935548169002309c12967135 +// License, authors and contributors informations can be found at bellow URLs respectively: +// https://code.google.com/p/leveldb-go/source/browse/LICENSE +// https://code.google.com/p/leveldb-go/source/browse/AUTHORS +// https://code.google.com/p/leveldb-go/source/browse/CONTRIBUTORS + +package journal + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "io/ioutil" + "math/rand" + "strings" + "testing" +) + +type dropper struct { + t *testing.T +} + +func (d dropper) Drop(err error) { + d.t.Log(err) +} + +func short(s string) string { + if len(s) < 64 { + return s + } + return fmt.Sprintf("%s...(skipping %d bytes)...%s", s[:20], len(s)-40, s[len(s)-20:]) +} + +// big returns a string of length n, composed of repetitions of partial. 
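Because every chunk carries a 7-byte header and chunks never cross 32 KiB block boundaries, a journal of n payload bytes costs roughly n plus 7 bytes per chunk on disk; TestFlush below works through the exact numbers. A back-of-the-envelope sketch, as if inside package journal (approxEncodedLen is illustrative and ignores block padding):

// approxEncodedLen estimates the encoded size of a single journal written at
// the start of a fresh block: one header per (blockSize-headerSize) payload
// bytes, rounded up. An empty journal still emits one full chunk.
func approxEncodedLen(payload int) int {
	chunks := (payload + (blockSize - headerSize) - 1) / (blockSize - headerSize)
	if chunks == 0 {
		chunks = 1
	}
	return payload + chunks*headerSize
}

// approxEncodedLen(1) == 8, approxEncodedLen(40000) == 40014, matching the
// per-record overheads worked out in TestFlush's comments.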
+func big(partial string, n int) string { + return strings.Repeat(partial, n/len(partial)+1)[:n] +} + +func TestEmpty(t *testing.T) { + buf := new(bytes.Buffer) + r := NewReader(buf, dropper{t}, true, true) + if _, err := r.Next(); err != io.EOF { + t.Fatalf("got %v, want %v", err, io.EOF) + } +} + +func testGenerator(t *testing.T, reset func(), gen func() (string, bool)) { + buf := new(bytes.Buffer) + + reset() + w := NewWriter(buf) + for { + s, ok := gen() + if !ok { + break + } + ww, err := w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write([]byte(s)); err != nil { + t.Fatal(err) + } + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + + reset() + r := NewReader(buf, dropper{t}, true, true) + for { + s, ok := gen() + if !ok { + break + } + rr, err := r.Next() + if err != nil { + t.Fatal(err) + } + x, err := ioutil.ReadAll(rr) + if err != nil { + t.Fatal(err) + } + if string(x) != s { + t.Fatalf("got %q, want %q", short(string(x)), short(s)) + } + } + if _, err := r.Next(); err != io.EOF { + t.Fatalf("got %v, want %v", err, io.EOF) + } +} + +func testLiterals(t *testing.T, s []string) { + var i int + reset := func() { + i = 0 + } + gen := func() (string, bool) { + if i == len(s) { + return "", false + } + i++ + return s[i-1], true + } + testGenerator(t, reset, gen) +} + +func TestMany(t *testing.T) { + const n = 1e5 + var i int + reset := func() { + i = 0 + } + gen := func() (string, bool) { + if i == n { + return "", false + } + i++ + return fmt.Sprintf("%d.", i-1), true + } + testGenerator(t, reset, gen) +} + +func TestRandom(t *testing.T) { + const n = 1e2 + var ( + i int + r *rand.Rand + ) + reset := func() { + i, r = 0, rand.New(rand.NewSource(0)) + } + gen := func() (string, bool) { + if i == n { + return "", false + } + i++ + return strings.Repeat(string(uint8(i)), r.Intn(2*blockSize+16)), true + } + testGenerator(t, reset, gen) +} + +func TestBasic(t *testing.T) { + testLiterals(t, []string{ + strings.Repeat("a", 1000), + strings.Repeat("b", 97270), + strings.Repeat("c", 8000), + }) +} + +func TestBoundary(t *testing.T) { + for i := blockSize - 16; i < blockSize+16; i++ { + s0 := big("abcd", i) + for j := blockSize - 16; j < blockSize+16; j++ { + s1 := big("ABCDE", j) + testLiterals(t, []string{s0, s1}) + testLiterals(t, []string{s0, "", s1}) + testLiterals(t, []string{s0, "x", s1}) + } + } +} + +func TestFlush(t *testing.T) { + buf := new(bytes.Buffer) + w := NewWriter(buf) + // Write a couple of records. Everything should still be held + // in the record.Writer buffer, so that buf.Len should be 0. + w0, _ := w.Next() + w0.Write([]byte("0")) + w1, _ := w.Next() + w1.Write([]byte("11")) + if got, want := buf.Len(), 0; got != want { + t.Fatalf("buffer length #0: got %d want %d", got, want) + } + // Flush the record.Writer buffer, which should yield 17 bytes. + // 17 = 2*7 + 1 + 2, which is two headers and 1 + 2 payload bytes. + if err := w.Flush(); err != nil { + t.Fatal(err) + } + if got, want := buf.Len(), 17; got != want { + t.Fatalf("buffer length #1: got %d want %d", got, want) + } + // Do another write, one that isn't large enough to complete the block. + // The write should not have flowed through to buf. + w2, _ := w.Next() + w2.Write(bytes.Repeat([]byte("2"), 10000)) + if got, want := buf.Len(), 17; got != want { + t.Fatalf("buffer length #2: got %d want %d", got, want) + } + // Flushing should get us up to 10024 bytes written. + // 10024 = 17 + 7 + 10000. 
+ if err := w.Flush(); err != nil { + t.Fatal(err) + } + if got, want := buf.Len(), 10024; got != want { + t.Fatalf("buffer length #3: got %d want %d", got, want) + } + // Do a bigger write, one that completes the current block. + // We should now have 32768 bytes (a complete block), without + // an explicit flush. + w3, _ := w.Next() + w3.Write(bytes.Repeat([]byte("3"), 40000)) + if got, want := buf.Len(), 32768; got != want { + t.Fatalf("buffer length #4: got %d want %d", got, want) + } + // Flushing should get us up to 50038 bytes written. + // 50038 = 10024 + 2*7 + 40000. There are two headers because + // the one record was split into two chunks. + if err := w.Flush(); err != nil { + t.Fatal(err) + } + if got, want := buf.Len(), 50038; got != want { + t.Fatalf("buffer length #5: got %d want %d", got, want) + } + // Check that reading those records give the right lengths. + r := NewReader(buf, dropper{t}, true, true) + wants := []int64{1, 2, 10000, 40000} + for i, want := range wants { + rr, _ := r.Next() + n, err := io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #%d: %v", i, err) + } + if n != want { + t.Fatalf("read #%d: got %d bytes want %d", i, n, want) + } + } +} + +func TestNonExhaustiveRead(t *testing.T) { + const n = 100 + buf := new(bytes.Buffer) + p := make([]byte, 10) + rnd := rand.New(rand.NewSource(1)) + + w := NewWriter(buf) + for i := 0; i < n; i++ { + length := len(p) + rnd.Intn(3*blockSize) + s := string(uint8(i)) + "123456789abcdefgh" + ww, _ := w.Next() + ww.Write([]byte(big(s, length))) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + + r := NewReader(buf, dropper{t}, true, true) + for i := 0; i < n; i++ { + rr, _ := r.Next() + _, err := io.ReadFull(rr, p) + if err != nil { + t.Fatal(err) + } + want := string(uint8(i)) + "123456789" + if got := string(p); got != want { + t.Fatalf("read #%d: got %q want %q", i, got, want) + } + } +} + +func TestStaleReader(t *testing.T) { + buf := new(bytes.Buffer) + + w := NewWriter(buf) + w0, err := w.Next() + if err != nil { + t.Fatal(err) + } + w0.Write([]byte("0")) + w1, err := w.Next() + if err != nil { + t.Fatal(err) + } + w1.Write([]byte("11")) + if err := w.Close(); err != nil { + t.Fatal(err) + } + + r := NewReader(buf, dropper{t}, true, true) + r0, err := r.Next() + if err != nil { + t.Fatal(err) + } + r1, err := r.Next() + if err != nil { + t.Fatal(err) + } + p := make([]byte, 1) + if _, err := r0.Read(p); err == nil || !strings.Contains(err.Error(), "stale") { + t.Fatalf("stale read #0: unexpected error: %v", err) + } + if _, err := r1.Read(p); err != nil { + t.Fatalf("fresh read #1: got %v want nil error", err) + } + if p[0] != '1' { + t.Fatalf("fresh read #1: byte contents: got '%c' want '1'", p[0]) + } +} + +func TestStaleWriter(t *testing.T) { + buf := new(bytes.Buffer) + + w := NewWriter(buf) + w0, err := w.Next() + if err != nil { + t.Fatal(err) + } + w1, err := w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := w0.Write([]byte("0")); err == nil || !strings.Contains(err.Error(), "stale") { + t.Fatalf("stale write #0: unexpected error: %v", err) + } + if _, err := w1.Write([]byte("11")); err != nil { + t.Fatalf("fresh write #1: got %v want nil error", err) + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + if _, err := w1.Write([]byte("0")); err == nil || !strings.Contains(err.Error(), "stale") { + t.Fatalf("stale write #1: unexpected error: %v", err) + } +} + +func TestCorrupt_MissingLastBlock(t *testing.T) { + buf := new(bytes.Buffer) + + w := 
NewWriter(buf) + + // First record. + ww, err := w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize-1024)); err != nil { + t.Fatalf("write #0: unexpected error: %v", err) + } + + // Second record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize-headerSize)); err != nil { + t.Fatalf("write #1: unexpected error: %v", err) + } + + if err := w.Close(); err != nil { + t.Fatal(err) + } + + // Cut the last block. + b := buf.Bytes()[:blockSize] + r := NewReader(bytes.NewReader(b), dropper{t}, false, true) + + // First read. + rr, err := r.Next() + if err != nil { + t.Fatal(err) + } + n, err := io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #0: %v", err) + } + if n != blockSize-1024 { + t.Fatalf("read #0: got %d bytes want %d", n, blockSize-1024) + } + + // Second read. + rr, err = r.Next() + if err != nil { + t.Fatal(err) + } + n, err = io.Copy(ioutil.Discard, rr) + if err != io.ErrUnexpectedEOF { + t.Fatalf("read #1: unexpected error: %v", err) + } + + if _, err := r.Next(); err != io.EOF { + t.Fatalf("last next: unexpected error: %v", err) + } +} + +func TestCorrupt_CorruptedFirstBlock(t *testing.T) { + buf := new(bytes.Buffer) + + w := NewWriter(buf) + + // First record. + ww, err := w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize/2)); err != nil { + t.Fatalf("write #0: unexpected error: %v", err) + } + + // Second record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize-headerSize)); err != nil { + t.Fatalf("write #1: unexpected error: %v", err) + } + + // Third record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), (blockSize-headerSize)+1)); err != nil { + t.Fatalf("write #2: unexpected error: %v", err) + } + + // Fourth record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), (blockSize-headerSize)+2)); err != nil { + t.Fatalf("write #3: unexpected error: %v", err) + } + + if err := w.Close(); err != nil { + t.Fatal(err) + } + + b := buf.Bytes() + // Corrupting block #0. + for i := 0; i < 1024; i++ { + b[i] = '1' + } + + r := NewReader(bytes.NewReader(b), dropper{t}, false, true) + + // First read (third record). + rr, err := r.Next() + if err != nil { + t.Fatal(err) + } + n, err := io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #0: %v", err) + } + if want := int64(blockSize-headerSize) + 1; n != want { + t.Fatalf("read #0: got %d bytes want %d", n, want) + } + + // Second read (fourth record). + rr, err = r.Next() + if err != nil { + t.Fatal(err) + } + n, err = io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #1: %v", err) + } + if want := int64(blockSize-headerSize) + 2; n != want { + t.Fatalf("read #1: got %d bytes want %d", n, want) + } + + if _, err := r.Next(); err != io.EOF { + t.Fatalf("last next: unexpected error: %v", err) + } +} + +func TestCorrupt_CorruptedMiddleBlock(t *testing.T) { + buf := new(bytes.Buffer) + + w := NewWriter(buf) + + // First record. + ww, err := w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize/2)); err != nil { + t.Fatalf("write #0: unexpected error: %v", err) + } + + // Second record. 
+ ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize-headerSize)); err != nil { + t.Fatalf("write #1: unexpected error: %v", err) + } + + // Third record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), (blockSize-headerSize)+1)); err != nil { + t.Fatalf("write #2: unexpected error: %v", err) + } + + // Fourth record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), (blockSize-headerSize)+2)); err != nil { + t.Fatalf("write #3: unexpected error: %v", err) + } + + if err := w.Close(); err != nil { + t.Fatal(err) + } + + b := buf.Bytes() + // Corrupting block #1. + for i := 0; i < 1024; i++ { + b[blockSize+i] = '1' + } + + r := NewReader(bytes.NewReader(b), dropper{t}, false, true) + + // First read (first record). + rr, err := r.Next() + if err != nil { + t.Fatal(err) + } + n, err := io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #0: %v", err) + } + if want := int64(blockSize / 2); n != want { + t.Fatalf("read #0: got %d bytes want %d", n, want) + } + + // Second read (second record). + rr, err = r.Next() + if err != nil { + t.Fatal(err) + } + n, err = io.Copy(ioutil.Discard, rr) + if err != io.ErrUnexpectedEOF { + t.Fatalf("read #1: unexpected error: %v", err) + } + + // Third read (fourth record). + rr, err = r.Next() + if err != nil { + t.Fatal(err) + } + n, err = io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #2: %v", err) + } + if want := int64(blockSize-headerSize) + 2; n != want { + t.Fatalf("read #2: got %d bytes want %d", n, want) + } + + if _, err := r.Next(); err != io.EOF { + t.Fatalf("last next: unexpected error: %v", err) + } +} + +func TestCorrupt_CorruptedLastBlock(t *testing.T) { + buf := new(bytes.Buffer) + + w := NewWriter(buf) + + // First record. + ww, err := w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize/2)); err != nil { + t.Fatalf("write #0: unexpected error: %v", err) + } + + // Second record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize-headerSize)); err != nil { + t.Fatalf("write #1: unexpected error: %v", err) + } + + // Third record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), (blockSize-headerSize)+1)); err != nil { + t.Fatalf("write #2: unexpected error: %v", err) + } + + // Fourth record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), (blockSize-headerSize)+2)); err != nil { + t.Fatalf("write #3: unexpected error: %v", err) + } + + if err := w.Close(); err != nil { + t.Fatal(err) + } + + b := buf.Bytes() + // Corrupting block #3. + for i := len(b) - 1; i > len(b)-1024; i-- { + b[i] = '1' + } + + r := NewReader(bytes.NewReader(b), dropper{t}, false, true) + + // First read (first record). + rr, err := r.Next() + if err != nil { + t.Fatal(err) + } + n, err := io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #0: %v", err) + } + if want := int64(blockSize / 2); n != want { + t.Fatalf("read #0: got %d bytes want %d", n, want) + } + + // Second read (second record). 
+ rr, err = r.Next() + if err != nil { + t.Fatal(err) + } + n, err = io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #1: %v", err) + } + if want := int64(blockSize - headerSize); n != want { + t.Fatalf("read #1: got %d bytes want %d", n, want) + } + + // Third read (third record). + rr, err = r.Next() + if err != nil { + t.Fatal(err) + } + n, err = io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #2: %v", err) + } + if want := int64(blockSize-headerSize) + 1; n != want { + t.Fatalf("read #2: got %d bytes want %d", n, want) + } + + // Fourth read (fourth record). + rr, err = r.Next() + if err != nil { + t.Fatal(err) + } + n, err = io.Copy(ioutil.Discard, rr) + if err != io.ErrUnexpectedEOF { + t.Fatalf("read #3: unexpected error: %v", err) + } + + if _, err := r.Next(); err != io.EOF { + t.Fatalf("last next: unexpected error: %v", err) + } +} + +func TestCorrupt_FirstChuckLengthOverflow(t *testing.T) { + buf := new(bytes.Buffer) + + w := NewWriter(buf) + + // First record. + ww, err := w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize/2)); err != nil { + t.Fatalf("write #0: unexpected error: %v", err) + } + + // Second record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize-headerSize)); err != nil { + t.Fatalf("write #1: unexpected error: %v", err) + } + + // Third record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), (blockSize-headerSize)+1)); err != nil { + t.Fatalf("write #2: unexpected error: %v", err) + } + + if err := w.Close(); err != nil { + t.Fatal(err) + } + + b := buf.Bytes() + // Corrupting record #1. + x := blockSize + binary.LittleEndian.PutUint16(b[x+4:], 0xffff) + + r := NewReader(bytes.NewReader(b), dropper{t}, false, true) + + // First read (first record). + rr, err := r.Next() + if err != nil { + t.Fatal(err) + } + n, err := io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #0: %v", err) + } + if want := int64(blockSize / 2); n != want { + t.Fatalf("read #0: got %d bytes want %d", n, want) + } + + // Second read (second record). + rr, err = r.Next() + if err != nil { + t.Fatal(err) + } + n, err = io.Copy(ioutil.Discard, rr) + if err != io.ErrUnexpectedEOF { + t.Fatalf("read #1: unexpected error: %v", err) + } + + if _, err := r.Next(); err != io.EOF { + t.Fatalf("last next: unexpected error: %v", err) + } +} + +func TestCorrupt_MiddleChuckLengthOverflow(t *testing.T) { + buf := new(bytes.Buffer) + + w := NewWriter(buf) + + // First record. + ww, err := w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize/2)); err != nil { + t.Fatalf("write #0: unexpected error: %v", err) + } + + // Second record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize-headerSize)); err != nil { + t.Fatalf("write #1: unexpected error: %v", err) + } + + // Third record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), (blockSize-headerSize)+1)); err != nil { + t.Fatalf("write #2: unexpected error: %v", err) + } + + if err := w.Close(); err != nil { + t.Fatal(err) + } + + b := buf.Bytes() + // Corrupting record #1. + x := blockSize/2 + headerSize + binary.LittleEndian.PutUint16(b[x+4:], 0xffff) + + r := NewReader(bytes.NewReader(b), dropper{t}, false, true) + + // First read (first record). 
+	rr, err := r.Next()
+	if err != nil {
+		t.Fatal(err)
+	}
+	n, err := io.Copy(ioutil.Discard, rr)
+	if err != nil {
+		t.Fatalf("read #0: %v", err)
+	}
+	if want := int64(blockSize / 2); n != want {
+		t.Fatalf("read #0: got %d bytes want %d", n, want)
+	}
+
+	// Second read (third record).
+	rr, err = r.Next()
+	if err != nil {
+		t.Fatal(err)
+	}
+	n, err = io.Copy(ioutil.Discard, rr)
+	if err != nil {
+		t.Fatalf("read #1: %v", err)
+	}
+	if want := int64(blockSize-headerSize) + 1; n != want {
+		t.Fatalf("read #1: got %d bytes want %d", n, want)
+	}
+
+	if _, err := r.Next(); err != io.EOF {
+		t.Fatalf("last next: unexpected error: %v", err)
+	}
+}
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/key.go b/vendor/github.com/syndtr/goleveldb/leveldb/key.go
new file mode 100644
index 0000000..ad8f51e
--- /dev/null
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/key.go
@@ -0,0 +1,143 @@
+// Copyright (c) 2012, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package leveldb
+
+import (
+	"encoding/binary"
+	"fmt"
+
+	"github.com/syndtr/goleveldb/leveldb/errors"
+	"github.com/syndtr/goleveldb/leveldb/storage"
+)
+
+// ErrInternalKeyCorrupted records internal key corruption.
+type ErrInternalKeyCorrupted struct {
+	Ikey   []byte
+	Reason string
+}
+
+func (e *ErrInternalKeyCorrupted) Error() string {
+	return fmt.Sprintf("leveldb: internal key %q corrupted: %s", e.Ikey, e.Reason)
+}
+
+func newErrInternalKeyCorrupted(ikey []byte, reason string) error {
+	return errors.NewErrCorrupted(storage.FileDesc{}, &ErrInternalKeyCorrupted{append([]byte{}, ikey...), reason})
+}
+
+type keyType uint
+
+func (kt keyType) String() string {
+	switch kt {
+	case keyTypeDel:
+		return "d"
+	case keyTypeVal:
+		return "v"
+	}
+	return fmt.Sprintf("<invalid:%#x>", uint(kt))
+}
+
+// Value types encoded as the last component of internal keys.
+// Don't modify; these values are saved to disk.
+const (
+	keyTypeDel = keyType(0)
+	keyTypeVal = keyType(1)
+)
+
+// keyTypeSeek defines the keyType that should be passed when constructing an
+// internal key for seeking to a particular sequence number (since we
+// sort sequence numbers in decreasing order and the value type is
+// embedded as the low 8 bits in the sequence number in internal keys,
+// we need to use the highest-numbered ValueType, not the lowest).
+const keyTypeSeek = keyTypeVal
+
+const (
+	// Maximum value possible for a sequence number; the 8 low bits are
+	// used by the value type, so the two can be packed together into a
+	// single 64-bit integer.
+	keyMaxSeq = (uint64(1) << 56) - 1
+	// Maximum value possible for packed sequence number and type.
+	keyMaxNum = (keyMaxSeq << 8) | uint64(keyTypeSeek)
+)
+
+// Maximum number encoded in bytes.
+var keyMaxNumBytes = make([]byte, 8)
+
+func init() {
+	binary.LittleEndian.PutUint64(keyMaxNumBytes, keyMaxNum)
+}
+
+type internalKey []byte
+
+func makeInternalKey(dst, ukey []byte, seq uint64, kt keyType) internalKey {
+	if seq > keyMaxSeq {
+		panic("leveldb: invalid sequence number")
+	} else if kt > keyTypeVal {
+		panic("leveldb: invalid type")
+	}
+
+	dst = ensureBuffer(dst, len(ukey)+8)
+	copy(dst, ukey)
+	binary.LittleEndian.PutUint64(dst[len(ukey):], (seq<<8)|uint64(kt))
+	return internalKey(dst)
+}
+
+func parseInternalKey(ik []byte) (ukey []byte, seq uint64, kt keyType, err error) {
+	if len(ik) < 8 {
+		return nil, 0, 0, newErrInternalKeyCorrupted(ik, "invalid length")
+	}
+	num := binary.LittleEndian.Uint64(ik[len(ik)-8:])
+	seq, kt = uint64(num>>8), keyType(num&0xff)
+	if kt > keyTypeVal {
+		return nil, 0, 0, newErrInternalKeyCorrupted(ik, "invalid type")
+	}
+	ukey = ik[:len(ik)-8]
+	return
+}
+
+func validInternalKey(ik []byte) bool {
+	_, _, _, err := parseInternalKey(ik)
+	return err == nil
+}
+
+func (ik internalKey) assert() {
+	if ik == nil {
+		panic("leveldb: nil internalKey")
+	}
+	if len(ik) < 8 {
+		panic(fmt.Sprintf("leveldb: internal key %q, len=%d: invalid length", []byte(ik), len(ik)))
+	}
+}
+
+func (ik internalKey) ukey() []byte {
+	ik.assert()
+	return ik[:len(ik)-8]
+}
+
+func (ik internalKey) num() uint64 {
+	ik.assert()
+	return binary.LittleEndian.Uint64(ik[len(ik)-8:])
+}
+
+func (ik internalKey) parseNum() (seq uint64, kt keyType) {
+	num := ik.num()
+	seq, kt = uint64(num>>8), keyType(num&0xff)
+	if kt > keyTypeVal {
+		panic(fmt.Sprintf("leveldb: internal key %q, len=%d: invalid type %#x", []byte(ik), len(ik), kt))
+	}
+	return
+}
+
+func (ik internalKey) String() string {
+	if ik == nil {
+		return "<nil>"
+	}
+
+	if ukey, seq, kt, err := parseInternalKey(ik); err == nil {
+		return fmt.Sprintf("%s,%s%d", shorten(string(ukey)), kt, seq)
+	}
+	return fmt.Sprintf("<invalid:%#x>", []byte(ik))
+}
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/key_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/key_test.go
new file mode 100644
index 0000000..2f33ccb
--- /dev/null
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/key_test.go
@@ -0,0 +1,133 @@
+// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
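Before the tests, the trailer packing above is easy to verify in isolation: an internal key is the user key followed by eight little-endian bytes holding (seq<<8 | keyType). A self-contained round-trip sketch, mirroring makeInternalKey/parseInternalKey without the package's validation (`makeIKey`/`parseIKey` are illustrative names):

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// makeIKey appends the 8-byte trailer (seq<<8 | kt) to the user key.
func makeIKey(ukey []byte, seq, kt uint64) []byte {
	dst := make([]byte, len(ukey)+8)
	copy(dst, ukey)
	binary.LittleEndian.PutUint64(dst[len(ukey):], (seq<<8)|kt)
	return dst
}

// parseIKey splits an internal key back into user key, sequence, and type.
func parseIKey(ik []byte) (ukey []byte, seq, kt uint64) {
	num := binary.LittleEndian.Uint64(ik[len(ik)-8:])
	return ik[:len(ik)-8], num >> 8, num & 0xff
}

func main() {
	ik := makeIKey([]byte("foo"), 100, 1) // seq 100, keyTypeVal == 1
	ukey, seq, kt := parseIKey(ik)
	fmt.Printf("%s %d %d\n", ukey, seq, kt) // foo 100 1
}
```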
+ +package leveldb + +import ( + "bytes" + "testing" + + "github.com/syndtr/goleveldb/leveldb/comparer" +) + +var defaultIComparer = &iComparer{comparer.DefaultComparer} + +func ikey(key string, seq uint64, kt keyType) internalKey { + return makeInternalKey(nil, []byte(key), uint64(seq), kt) +} + +func shortSep(a, b []byte) []byte { + dst := make([]byte, len(a)) + dst = defaultIComparer.Separator(dst[:0], a, b) + if dst == nil { + return a + } + return dst +} + +func shortSuccessor(b []byte) []byte { + dst := make([]byte, len(b)) + dst = defaultIComparer.Successor(dst[:0], b) + if dst == nil { + return b + } + return dst +} + +func testSingleKey(t *testing.T, key string, seq uint64, kt keyType) { + ik := ikey(key, seq, kt) + + if !bytes.Equal(ik.ukey(), []byte(key)) { + t.Errorf("user key does not equal, got %v, want %v", string(ik.ukey()), key) + } + + rseq, rt := ik.parseNum() + if rseq != seq { + t.Errorf("seq number does not equal, got %v, want %v", rseq, seq) + } + if rt != kt { + t.Errorf("type does not equal, got %v, want %v", rt, kt) + } + + if rukey, rseq, rt, kerr := parseInternalKey(ik); kerr == nil { + if !bytes.Equal(rukey, []byte(key)) { + t.Errorf("user key does not equal, got %v, want %v", string(ik.ukey()), key) + } + if rseq != seq { + t.Errorf("seq number does not equal, got %v, want %v", rseq, seq) + } + if rt != kt { + t.Errorf("type does not equal, got %v, want %v", rt, kt) + } + } else { + t.Errorf("key error: %v", kerr) + } +} + +func TestInternalKey_EncodeDecode(t *testing.T) { + keys := []string{"", "k", "hello", "longggggggggggggggggggggg"} + seqs := []uint64{ + 1, 2, 3, + (1 << 8) - 1, 1 << 8, (1 << 8) + 1, + (1 << 16) - 1, 1 << 16, (1 << 16) + 1, + (1 << 32) - 1, 1 << 32, (1 << 32) + 1, + } + for _, key := range keys { + for _, seq := range seqs { + testSingleKey(t, key, seq, keyTypeVal) + testSingleKey(t, "hello", 1, keyTypeDel) + } + } +} + +func assertBytes(t *testing.T, want, got []byte) { + if !bytes.Equal(got, want) { + t.Errorf("assert failed, got %v, want %v", got, want) + } +} + +func TestInternalKeyShortSeparator(t *testing.T) { + // When user keys are same + assertBytes(t, ikey("foo", 100, keyTypeVal), + shortSep(ikey("foo", 100, keyTypeVal), + ikey("foo", 99, keyTypeVal))) + assertBytes(t, ikey("foo", 100, keyTypeVal), + shortSep(ikey("foo", 100, keyTypeVal), + ikey("foo", 101, keyTypeVal))) + assertBytes(t, ikey("foo", 100, keyTypeVal), + shortSep(ikey("foo", 100, keyTypeVal), + ikey("foo", 100, keyTypeVal))) + assertBytes(t, ikey("foo", 100, keyTypeVal), + shortSep(ikey("foo", 100, keyTypeVal), + ikey("foo", 100, keyTypeDel))) + + // When user keys are misordered + assertBytes(t, ikey("foo", 100, keyTypeVal), + shortSep(ikey("foo", 100, keyTypeVal), + ikey("bar", 99, keyTypeVal))) + + // When user keys are different, but correctly ordered + assertBytes(t, ikey("g", uint64(keyMaxSeq), keyTypeSeek), + shortSep(ikey("foo", 100, keyTypeVal), + ikey("hello", 200, keyTypeVal))) + + // When start user key is prefix of limit user key + assertBytes(t, ikey("foo", 100, keyTypeVal), + shortSep(ikey("foo", 100, keyTypeVal), + ikey("foobar", 200, keyTypeVal))) + + // When limit user key is prefix of start user key + assertBytes(t, ikey("foobar", 100, keyTypeVal), + shortSep(ikey("foobar", 100, keyTypeVal), + ikey("foo", 200, keyTypeVal))) +} + +func TestInternalKeyShortestSuccessor(t *testing.T) { + assertBytes(t, ikey("g", uint64(keyMaxSeq), keyTypeSeek), + shortSuccessor(ikey("foo", 100, keyTypeVal))) + assertBytes(t, ikey("\xff\xff", 100, keyTypeVal), + 
shortSuccessor(ikey("\xff\xff", 100, keyTypeVal))) +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/leveldb_suite_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/leveldb_suite_test.go new file mode 100644 index 0000000..fefa007 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/leveldb_suite_test.go @@ -0,0 +1,11 @@ +package leveldb + +import ( + "testing" + + "github.com/syndtr/goleveldb/leveldb/testutil" +) + +func TestLevelDB(t *testing.T) { + testutil.RunSuite(t, "LevelDB Suite") +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/memdb/bench_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/memdb/bench_test.go new file mode 100644 index 0000000..b05084c --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/memdb/bench_test.go @@ -0,0 +1,75 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package memdb + +import ( + "encoding/binary" + "math/rand" + "testing" + + "github.com/syndtr/goleveldb/leveldb/comparer" +) + +func BenchmarkPut(b *testing.B) { + buf := make([][4]byte, b.N) + for i := range buf { + binary.LittleEndian.PutUint32(buf[i][:], uint32(i)) + } + + b.ResetTimer() + p := New(comparer.DefaultComparer, 0) + for i := range buf { + p.Put(buf[i][:], nil) + } +} + +func BenchmarkPutRandom(b *testing.B) { + buf := make([][4]byte, b.N) + for i := range buf { + binary.LittleEndian.PutUint32(buf[i][:], uint32(rand.Int())) + } + + b.ResetTimer() + p := New(comparer.DefaultComparer, 0) + for i := range buf { + p.Put(buf[i][:], nil) + } +} + +func BenchmarkGet(b *testing.B) { + buf := make([][4]byte, b.N) + for i := range buf { + binary.LittleEndian.PutUint32(buf[i][:], uint32(i)) + } + + p := New(comparer.DefaultComparer, 0) + for i := range buf { + p.Put(buf[i][:], nil) + } + + b.ResetTimer() + for i := range buf { + p.Get(buf[i][:]) + } +} + +func BenchmarkGetRandom(b *testing.B) { + buf := make([][4]byte, b.N) + for i := range buf { + binary.LittleEndian.PutUint32(buf[i][:], uint32(i)) + } + + p := New(comparer.DefaultComparer, 0) + for i := range buf { + p.Put(buf[i][:], nil) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + p.Get(buf[rand.Int()%b.N][:]) + } +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/memdb/memdb.go b/vendor/github.com/syndtr/goleveldb/leveldb/memdb/memdb.go new file mode 100644 index 0000000..824e47f --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/memdb/memdb.go @@ -0,0 +1,479 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Package memdb provides in-memory key/value database implementation. +package memdb + +import ( + "math/rand" + "sync" + + "github.com/syndtr/goleveldb/leveldb/comparer" + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/util" +) + +// Common errors. 
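Before the implementation, a quick sketch of how the package's public surface defined below (New, Put, Get, Find, NewIterator) fits together; key ordering comes from the comparer passed to New:

```go
package main

import (
	"fmt"

	"github.com/syndtr/goleveldb/leveldb/comparer"
	"github.com/syndtr/goleveldb/leveldb/memdb"
)

func main() {
	db := memdb.New(comparer.DefaultComparer, 0)
	_ = db.Put([]byte("a"), []byte("1"))
	_ = db.Put([]byte("c"), []byte("3"))

	if v, err := db.Get([]byte("a")); err == nil {
		fmt.Printf("a=%s\n", v)
	}
	// Find returns the first pair whose key is >= the argument.
	if k, v, err := db.Find([]byte("b")); err == nil {
		fmt.Printf("find(b) -> %s=%s\n", k, v) // find(b) -> c=3
	}

	it := db.NewIterator(nil) // a nil Range iterates everything, in key order
	for it.Next() {
		fmt.Printf("%s=%s\n", it.Key(), it.Value())
	}
	it.Release()
}
```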
+var ( + ErrNotFound = errors.ErrNotFound + ErrIterReleased = errors.New("leveldb/memdb: iterator released") +) + +const tMaxHeight = 12 + +type dbIter struct { + util.BasicReleaser + p *DB + slice *util.Range + node int + forward bool + key, value []byte + err error +} + +func (i *dbIter) fill(checkStart, checkLimit bool) bool { + if i.node != 0 { + n := i.p.nodeData[i.node] + m := n + i.p.nodeData[i.node+nKey] + i.key = i.p.kvData[n:m] + if i.slice != nil { + switch { + case checkLimit && i.slice.Limit != nil && i.p.cmp.Compare(i.key, i.slice.Limit) >= 0: + fallthrough + case checkStart && i.slice.Start != nil && i.p.cmp.Compare(i.key, i.slice.Start) < 0: + i.node = 0 + goto bail + } + } + i.value = i.p.kvData[m : m+i.p.nodeData[i.node+nVal]] + return true + } +bail: + i.key = nil + i.value = nil + return false +} + +func (i *dbIter) Valid() bool { + return i.node != 0 +} + +func (i *dbIter) First() bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + i.forward = true + i.p.mu.RLock() + defer i.p.mu.RUnlock() + if i.slice != nil && i.slice.Start != nil { + i.node, _ = i.p.findGE(i.slice.Start, false) + } else { + i.node = i.p.nodeData[nNext] + } + return i.fill(false, true) +} + +func (i *dbIter) Last() bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + i.forward = false + i.p.mu.RLock() + defer i.p.mu.RUnlock() + if i.slice != nil && i.slice.Limit != nil { + i.node = i.p.findLT(i.slice.Limit) + } else { + i.node = i.p.findLast() + } + return i.fill(true, false) +} + +func (i *dbIter) Seek(key []byte) bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + i.forward = true + i.p.mu.RLock() + defer i.p.mu.RUnlock() + if i.slice != nil && i.slice.Start != nil && i.p.cmp.Compare(key, i.slice.Start) < 0 { + key = i.slice.Start + } + i.node, _ = i.p.findGE(key, false) + return i.fill(false, true) +} + +func (i *dbIter) Next() bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + if i.node == 0 { + if !i.forward { + return i.First() + } + return false + } + i.forward = true + i.p.mu.RLock() + defer i.p.mu.RUnlock() + i.node = i.p.nodeData[i.node+nNext] + return i.fill(false, true) +} + +func (i *dbIter) Prev() bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + if i.node == 0 { + if i.forward { + return i.Last() + } + return false + } + i.forward = false + i.p.mu.RLock() + defer i.p.mu.RUnlock() + i.node = i.p.findLT(i.key) + return i.fill(true, false) +} + +func (i *dbIter) Key() []byte { + return i.key +} + +func (i *dbIter) Value() []byte { + return i.value +} + +func (i *dbIter) Error() error { return i.err } + +func (i *dbIter) Release() { + if !i.Released() { + i.p = nil + i.node = 0 + i.key = nil + i.value = nil + i.BasicReleaser.Release() + } +} + +const ( + nKV = iota + nKey + nVal + nHeight + nNext +) + +// DB is an in-memory key/value database. +type DB struct { + cmp comparer.BasicComparer + rnd *rand.Rand + + mu sync.RWMutex + kvData []byte + // Node data: + // [0] : KV offset + // [1] : Key length + // [2] : Value length + // [3] : Height + // [3..height] : Next nodes + nodeData []int + prevNode [tMaxHeight]int + maxHeight int + n int + kvSize int +} + +func (p *DB) randHeight() (h int) { + const branching = 4 + h = 1 + for h < tMaxHeight && p.rnd.Int()%branching == 0 { + h++ + } + return +} + +// Must hold RW-lock if prev == true, as it use shared prevNode slice. 
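The nodeData layout documented on the DB struct above is a skip list flattened into parallel slices: each node is a variable-length run of ints, and the next "pointers" are plain indexes back into nodeData itself. A toy decode of a single node under that layout (the constants mirror nKV..nNext above; `nodeKV` is an illustrative helper, not package code):

```go
package main

import "fmt"

// Mirrors the node layout: [0] KV offset, [1] key length, [2] value length,
// [3] height, then one next-node index per level.
const (
	nKV = iota
	nKey
	nVal
	nHeight
	nNext
)

// nodeKV slices one node's key and value out of the shared kvData buffer.
func nodeKV(kvData []byte, nodeData []int, node int) (key, value []byte) {
	o := nodeData[node+nKV]
	key = kvData[o : o+nodeData[node+nKey]]
	value = kvData[o+nodeData[node+nKey] : o+nodeData[node+nKey]+nodeData[node+nVal]]
	return
}

func main() {
	kvData := []byte("foo1")         // "foo" -> "1", stored back to back
	nodeData := []int{0, 3, 1, 1, 0} // offset 0, keyLen 3, valLen 1, height 1, next 0
	k, v := nodeKV(kvData, nodeData, 0)
	fmt.Printf("%s=%s\n", k, v) // foo=1
}
```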
+func (p *DB) findGE(key []byte, prev bool) (int, bool) { + node := 0 + h := p.maxHeight - 1 + for { + next := p.nodeData[node+nNext+h] + cmp := 1 + if next != 0 { + o := p.nodeData[next] + cmp = p.cmp.Compare(p.kvData[o:o+p.nodeData[next+nKey]], key) + } + if cmp < 0 { + // Keep searching in this list + node = next + } else { + if prev { + p.prevNode[h] = node + } else if cmp == 0 { + return next, true + } + if h == 0 { + return next, cmp == 0 + } + h-- + } + } +} + +func (p *DB) findLT(key []byte) int { + node := 0 + h := p.maxHeight - 1 + for { + next := p.nodeData[node+nNext+h] + o := p.nodeData[next] + if next == 0 || p.cmp.Compare(p.kvData[o:o+p.nodeData[next+nKey]], key) >= 0 { + if h == 0 { + break + } + h-- + } else { + node = next + } + } + return node +} + +func (p *DB) findLast() int { + node := 0 + h := p.maxHeight - 1 + for { + next := p.nodeData[node+nNext+h] + if next == 0 { + if h == 0 { + break + } + h-- + } else { + node = next + } + } + return node +} + +// Put sets the value for the given key. It overwrites any previous value +// for that key; a DB is not a multi-map. +// +// It is safe to modify the contents of the arguments after Put returns. +func (p *DB) Put(key []byte, value []byte) error { + p.mu.Lock() + defer p.mu.Unlock() + + if node, exact := p.findGE(key, true); exact { + kvOffset := len(p.kvData) + p.kvData = append(p.kvData, key...) + p.kvData = append(p.kvData, value...) + p.nodeData[node] = kvOffset + m := p.nodeData[node+nVal] + p.nodeData[node+nVal] = len(value) + p.kvSize += len(value) - m + return nil + } + + h := p.randHeight() + if h > p.maxHeight { + for i := p.maxHeight; i < h; i++ { + p.prevNode[i] = 0 + } + p.maxHeight = h + } + + kvOffset := len(p.kvData) + p.kvData = append(p.kvData, key...) + p.kvData = append(p.kvData, value...) + // Node + node := len(p.nodeData) + p.nodeData = append(p.nodeData, kvOffset, len(key), len(value), h) + for i, n := range p.prevNode[:h] { + m := n + nNext + i + p.nodeData = append(p.nodeData, p.nodeData[m]) + p.nodeData[m] = node + } + + p.kvSize += len(key) + len(value) + p.n++ + return nil +} + +// Delete deletes the value for the given key. It returns ErrNotFound if +// the DB does not contain the key. +// +// It is safe to modify the contents of the arguments after Delete returns. +func (p *DB) Delete(key []byte) error { + p.mu.Lock() + defer p.mu.Unlock() + + node, exact := p.findGE(key, true) + if !exact { + return ErrNotFound + } + + h := p.nodeData[node+nHeight] + for i, n := range p.prevNode[:h] { + m := n + nNext + i + p.nodeData[m] = p.nodeData[p.nodeData[m]+nNext+i] + } + + p.kvSize -= p.nodeData[node+nKey] + p.nodeData[node+nVal] + p.n-- + return nil +} + +// Contains returns true if the given key are in the DB. +// +// It is safe to modify the contents of the arguments after Contains returns. +func (p *DB) Contains(key []byte) bool { + p.mu.RLock() + _, exact := p.findGE(key, false) + p.mu.RUnlock() + return exact +} + +// Get gets the value for the given key. It returns error.ErrNotFound if the +// DB does not contain the key. +// +// The caller should not modify the contents of the returned slice, but +// it is safe to modify the contents of the argument after Get returns. 
+func (p *DB) Get(key []byte) (value []byte, err error) {
+	p.mu.RLock()
+	if node, exact := p.findGE(key, false); exact {
+		o := p.nodeData[node] + p.nodeData[node+nKey]
+		value = p.kvData[o : o+p.nodeData[node+nVal]]
+	} else {
+		err = ErrNotFound
+	}
+	p.mu.RUnlock()
+	return
+}
+
+// Find finds the key/value pair whose key is greater than or equal to the
+// given key. It returns ErrNotFound if the table doesn't contain
+// such a pair.
+//
+// The caller should not modify the contents of the returned slices, but
+// it is safe to modify the contents of the argument after Find returns.
+func (p *DB) Find(key []byte) (rkey, value []byte, err error) {
+	p.mu.RLock()
+	if node, _ := p.findGE(key, false); node != 0 {
+		n := p.nodeData[node]
+		m := n + p.nodeData[node+nKey]
+		rkey = p.kvData[n:m]
+		value = p.kvData[m : m+p.nodeData[node+nVal]]
+	} else {
+		err = ErrNotFound
+	}
+	p.mu.RUnlock()
+	return
+}
+
+// NewIterator returns an iterator of the DB.
+// The returned iterator is not safe for concurrent use, but it is safe to use
+// multiple iterators concurrently, with each in a dedicated goroutine.
+// It is also safe to use an iterator concurrently with modifying its
+// underlying DB. However, the resultant key/value pairs are not guaranteed
+// to be a consistent snapshot of the DB at a particular point in time.
+//
+// Slice allows slicing the iterator so that it only contains keys in the
+// given range. A nil Range.Start is treated as a key before all keys in the
+// DB. And a nil Range.Limit is treated as a key after all keys in
+// the DB.
+//
+// WARNING: The content of any slice returned by the iterator (e.g. a slice
+// returned by calling the Iterator.Key() or Iterator.Value() methods) should
+// not be modified unless noted otherwise.
+//
+// The iterator must be released after use, by calling the Release method.
+//
+// Also read the Iterator documentation of the leveldb/iterator package.
+func (p *DB) NewIterator(slice *util.Range) iterator.Iterator {
+	return &dbIter{p: p, slice: slice}
+}
+
+// Capacity returns the capacity of the keys/values buffer.
+func (p *DB) Capacity() int {
+	p.mu.RLock()
+	defer p.mu.RUnlock()
+	return cap(p.kvData)
+}
+
+// Size returns the sum of the key and value lengths. Note that deleted
+// key/value pairs will not be accounted for, but they will still consume
+// the buffer, since the buffer is append-only.
+func (p *DB) Size() int {
+	p.mu.RLock()
+	defer p.mu.RUnlock()
+	return p.kvSize
+}
+
+// Free returns the free space remaining in the keys/values buffer before it
+// needs to grow.
+func (p *DB) Free() int {
+	p.mu.RLock()
+	defer p.mu.RUnlock()
+	return cap(p.kvData) - len(p.kvData)
+}
+
+// Len returns the number of entries in the DB.
+func (p *DB) Len() int {
+	p.mu.RLock()
+	defer p.mu.RUnlock()
+	return p.n
+}
+
+// Reset resets the DB to its initial empty state. Allows reusing the buffer.
+func (p *DB) Reset() {
+	p.mu.Lock()
+	p.rnd = rand.New(rand.NewSource(0xdeadbeef))
+	p.maxHeight = 1
+	p.n = 0
+	p.kvSize = 0
+	p.kvData = p.kvData[:0]
+	p.nodeData = p.nodeData[:nNext+tMaxHeight]
+	p.nodeData[nKV] = 0
+	p.nodeData[nKey] = 0
+	p.nodeData[nVal] = 0
+	p.nodeData[nHeight] = tMaxHeight
+	for n := 0; n < tMaxHeight; n++ {
+		p.nodeData[nNext+n] = 0
+		p.prevNode[n] = 0
+	}
+	p.mu.Unlock()
+}
+
+// New creates a new initialized in-memory key/value DB. The capacity
+// is the initial key/value buffer capacity. The capacity is advisory,
+// not enforced.
+//
+// This DB is append-only: deleting an entry removes its node but does not
+// reclaim the KV buffer.
+//
+// The returned DB instance is safe for concurrent use.
+func New(cmp comparer.BasicComparer, capacity int) *DB { + p := &DB{ + cmp: cmp, + rnd: rand.New(rand.NewSource(0xdeadbeef)), + maxHeight: 1, + kvData: make([]byte, 0, capacity), + nodeData: make([]int, 4+tMaxHeight), + } + p.nodeData[nHeight] = tMaxHeight + return p +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/memdb/memdb_suite_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/memdb/memdb_suite_test.go new file mode 100644 index 0000000..18c304b --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/memdb/memdb_suite_test.go @@ -0,0 +1,11 @@ +package memdb + +import ( + "testing" + + "github.com/syndtr/goleveldb/leveldb/testutil" +) + +func TestMemDB(t *testing.T) { + testutil.RunSuite(t, "MemDB Suite") +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/memdb/memdb_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/memdb/memdb_test.go new file mode 100644 index 0000000..3f0a31e --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/memdb/memdb_test.go @@ -0,0 +1,135 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package memdb + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "github.com/syndtr/goleveldb/leveldb/comparer" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/testutil" + "github.com/syndtr/goleveldb/leveldb/util" +) + +func (p *DB) TestFindLT(key []byte) (rkey, value []byte, err error) { + p.mu.RLock() + if node := p.findLT(key); node != 0 { + n := p.nodeData[node] + m := n + p.nodeData[node+nKey] + rkey = p.kvData[n:m] + value = p.kvData[m : m+p.nodeData[node+nVal]] + } else { + err = ErrNotFound + } + p.mu.RUnlock() + return +} + +func (p *DB) TestFindLast() (rkey, value []byte, err error) { + p.mu.RLock() + if node := p.findLast(); node != 0 { + n := p.nodeData[node] + m := n + p.nodeData[node+nKey] + rkey = p.kvData[n:m] + value = p.kvData[m : m+p.nodeData[node+nVal]] + } else { + err = ErrNotFound + } + p.mu.RUnlock() + return +} + +func (p *DB) TestPut(key []byte, value []byte) error { + p.Put(key, value) + return nil +} + +func (p *DB) TestDelete(key []byte) error { + p.Delete(key) + return nil +} + +func (p *DB) TestFind(key []byte) (rkey, rvalue []byte, err error) { + return p.Find(key) +} + +func (p *DB) TestGet(key []byte) (value []byte, err error) { + return p.Get(key) +} + +func (p *DB) TestNewIterator(slice *util.Range) iterator.Iterator { + return p.NewIterator(slice) +} + +var _ = testutil.Defer(func() { + Describe("Memdb", func() { + Describe("write test", func() { + It("should do write correctly", func() { + db := New(comparer.DefaultComparer, 0) + t := testutil.DBTesting{ + DB: db, + Deleted: testutil.KeyValue_Generate(nil, 1000, 1, 1, 30, 5, 5).Clone(), + PostFn: func(t *testutil.DBTesting) { + Expect(db.Len()).Should(Equal(t.Present.Len())) + Expect(db.Size()).Should(Equal(t.Present.Size())) + switch t.Act { + case testutil.DBPut, testutil.DBOverwrite: + Expect(db.Contains(t.ActKey)).Should(BeTrue()) + default: + Expect(db.Contains(t.ActKey)).Should(BeFalse()) + } + }, + } + testutil.DoDBTesting(&t) + }) + }) + + Describe("read test", func() { + testutil.AllKeyValueTesting(nil, func(kv testutil.KeyValue) testutil.DB { + // Building the DB. 
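A consequence of the append-only design documented on New: Delete unlinks the node and shrinks Len and Size immediately, but the KV buffer keeps the bytes until Reset truncates it. A small sketch against the API above:

```go
package main

import (
	"fmt"

	"github.com/syndtr/goleveldb/leveldb/comparer"
	"github.com/syndtr/goleveldb/leveldb/memdb"
)

func main() {
	db := memdb.New(comparer.DefaultComparer, 0)
	_ = db.Put([]byte("key"), []byte("value"))
	fmt.Println(db.Len(), db.Size()) // 1 8

	// Delete subtracts from the accounted size, but the bytes stay in the
	// append-only buffer, so Capacity does not shrink.
	_ = db.Delete([]byte("key"))
	fmt.Println(db.Len(), db.Size(), db.Capacity() >= 8) // 0 0 true

	db.Reset() // truncates the buffers so they can be reused
	fmt.Println(db.Len(), db.Size()) // 0 0
}
```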
+ db := New(comparer.DefaultComparer, 0) + kv.IterateShuffled(nil, func(i int, key, value []byte) { + db.Put(key, value) + }) + + if kv.Len() > 1 { + It("Should find correct keys with findLT", func() { + testutil.ShuffledIndex(nil, kv.Len()-1, 1, func(i int) { + key_, key, _ := kv.IndexInexact(i + 1) + expectedKey, expectedValue := kv.Index(i) + + // Using key that exist. + rkey, rvalue, err := db.TestFindLT(key) + Expect(err).ShouldNot(HaveOccurred(), "Error for key %q -> %q", key, expectedKey) + Expect(rkey).Should(Equal(expectedKey), "Key") + Expect(rvalue).Should(Equal(expectedValue), "Value for key %q -> %q", key, expectedKey) + + // Using key that doesn't exist. + rkey, rvalue, err = db.TestFindLT(key_) + Expect(err).ShouldNot(HaveOccurred(), "Error for key %q (%q) -> %q", key_, key, expectedKey) + Expect(rkey).Should(Equal(expectedKey)) + Expect(rvalue).Should(Equal(expectedValue), "Value for key %q (%q) -> %q", key_, key, expectedKey) + }) + }) + } + + if kv.Len() > 0 { + It("Should find last key with findLast", func() { + key, value := kv.Index(kv.Len() - 1) + rkey, rvalue, err := db.TestFindLast() + Expect(err).ShouldNot(HaveOccurred()) + Expect(rkey).Should(Equal(key)) + Expect(rvalue).Should(Equal(value)) + }) + } + + return db + }, nil, nil) + }) + }) +}) diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/opt/options.go b/vendor/github.com/syndtr/goleveldb/leveldb/opt/options.go new file mode 100644 index 0000000..c02c1e9 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/opt/options.go @@ -0,0 +1,716 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Package opt provides sets of options used by LevelDB. +package opt + +import ( + "math" + + "github.com/syndtr/goleveldb/leveldb/cache" + "github.com/syndtr/goleveldb/leveldb/comparer" + "github.com/syndtr/goleveldb/leveldb/filter" +) + +const ( + KiB = 1024 + MiB = KiB * 1024 + GiB = MiB * 1024 +) + +var ( + DefaultBlockCacher = LRUCacher + DefaultBlockCacheCapacity = 8 * MiB + DefaultBlockRestartInterval = 16 + DefaultBlockSize = 4 * KiB + DefaultCompactionExpandLimitFactor = 25 + DefaultCompactionGPOverlapsFactor = 10 + DefaultCompactionL0Trigger = 4 + DefaultCompactionSourceLimitFactor = 1 + DefaultCompactionTableSize = 2 * MiB + DefaultCompactionTableSizeMultiplier = 1.0 + DefaultCompactionTotalSize = 10 * MiB + DefaultCompactionTotalSizeMultiplier = 10.0 + DefaultCompressionType = SnappyCompression + DefaultIteratorSamplingRate = 1 * MiB + DefaultOpenFilesCacher = LRUCacher + DefaultOpenFilesCacheCapacity = 500 + DefaultWriteBuffer = 4 * MiB + DefaultWriteL0PauseTrigger = 12 + DefaultWriteL0SlowdownTrigger = 8 +) + +// Cacher is a caching algorithm. +type Cacher interface { + New(capacity int) cache.Cacher +} + +type CacherFunc struct { + NewFunc func(capacity int) cache.Cacher +} + +func (f *CacherFunc) New(capacity int) cache.Cacher { + if f.NewFunc != nil { + return f.NewFunc(capacity) + } + return nil +} + +func noCacher(int) cache.Cacher { return nil } + +var ( + // LRUCacher is the LRU-cache algorithm. + LRUCacher = &CacherFunc{cache.NewLRU} + + // NoCacher is the value to disable caching algorithm. + NoCacher = &CacherFunc{} +) + +// Compression is the 'sorted table' block compression algorithm to use. 
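A note on the Cacher plumbing above: NoCacher is recognized by identity in the option accessors, so setting it yields a nil cache algorithm, while a nil or zero Options value falls back to the defaults. A small sketch (GetBlockCacher and GetBlockCacheCapacity are the accessors defined later in this file):

```go
package main

import (
	"fmt"

	"github.com/syndtr/goleveldb/leveldb/opt"
)

func main() {
	// A nil *Options is safe: every accessor nil-checks the receiver and
	// falls back to the Default* values above.
	var o *opt.Options
	fmt.Println(o.GetBlockCacheCapacity()) // 8388608 (8 MiB)

	// NoCacher short-circuits the accessor to a nil cache algorithm.
	o = &opt.Options{BlockCacher: opt.NoCacher}
	fmt.Println(o.GetBlockCacher() == nil) // true
}
```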
+type Compression uint
+
+func (c Compression) String() string {
+	switch c {
+	case DefaultCompression:
+		return "default"
+	case NoCompression:
+		return "none"
+	case SnappyCompression:
+		return "snappy"
+	}
+	return "invalid"
+}
+
+const (
+	DefaultCompression Compression = iota
+	NoCompression
+	SnappyCompression
+	nCompression
+)
+
+// Strict is the DB 'strict level'.
+type Strict uint
+
+const (
+	// If present then a corrupted or invalid chunk or block in the manifest
+	// journal will cause an error instead of being dropped.
+	// This prevents a database with a corrupted manifest from being opened.
+	StrictManifest Strict = 1 << iota
+
+	// If present then journal chunk checksums will be verified.
+	StrictJournalChecksum
+
+	// If present then a corrupted or invalid chunk or block in the journal
+	// will cause an error instead of being dropped.
+	// This prevents a database with a corrupted journal from being opened.
+	StrictJournal
+
+	// If present then 'sorted table' block checksums will be verified.
+	// This affects both 'read operation' and compaction.
+	StrictBlockChecksum
+
+	// If present then a corrupted 'sorted table' will fail compaction.
+	// The database will enter read-only mode.
+	StrictCompaction
+
+	// If present then a corrupted 'sorted table' will halt 'read operation'.
+	StrictReader
+
+	// If present then leveldb.Recover will drop corrupted 'sorted table'.
+	StrictRecovery
+
+	// This is only applicable to ReadOptions; if present then this ReadOptions
+	// 'strict level' will override the global one.
+	StrictOverride
+
+	// StrictAll enables all strict flags.
+	StrictAll = StrictManifest | StrictJournalChecksum | StrictJournal | StrictBlockChecksum | StrictCompaction | StrictReader | StrictRecovery
+
+	// DefaultStrict is the default set of strict flags. Specifying any strict
+	// flags will override the default strict flags as a whole (i.e. they are
+	// not OR'ed).
+	DefaultStrict = StrictJournalChecksum | StrictBlockChecksum | StrictCompaction | StrictReader
+
+	// NoStrict disables all strict flags, overriding the default strict flags.
+	NoStrict = ^StrictAll
+)
+
+// Options holds the optional parameters for the DB at large.
+type Options struct {
+	// AltFilters defines one or more 'alternative filters'.
+	// 'alternative filters' will be used during reads if a filter block
+	// does not match the 'effective filter'.
+	//
+	// The default value is nil.
+	AltFilters []filter.Filter
+
+	// BlockCacher provides the cache algorithm for LevelDB 'sorted table'
+	// block caching. Specify NoCacher to disable the caching algorithm.
+	//
+	// The default value is LRUCacher.
+	BlockCacher Cacher
+
+	// BlockCacheCapacity defines the capacity of the 'sorted table' block caching.
+	// Use -1 for zero; this has the same effect as specifying NoCacher for
+	// BlockCacher.
+	//
+	// The default value is 8MiB.
+	BlockCacheCapacity int
+
+	// BlockCacheEvictRemoved enables forced eviction of cached blocks belonging
+	// to a removed 'sorted table'.
+	//
+	// The default is false.
+	BlockCacheEvictRemoved bool
+
+	// BlockRestartInterval is the number of keys between restart points for
+	// delta encoding of keys.
+	//
+	// The default value is 16.
+	BlockRestartInterval int
+
+	// BlockSize is the minimum uncompressed size in bytes of each 'sorted table'
+	// block.
+	//
+	// The default value is 4KiB.
+	BlockSize int
+
+	// CompactionExpandLimitFactor limits the compaction size after expansion.
+	// This will be multiplied by the table size limit at the compaction target
+	// level.
+	//
+	// The default value is 25.
+	CompactionExpandLimitFactor int
+
+	// CompactionGPOverlapsFactor limits the overlaps in grandparent (Level + 2)
+	// that a single 'sorted table' generates.
+	// This will be multiplied by the table size limit at the grandparent level.
+	//
+	// The default value is 10.
+	CompactionGPOverlapsFactor int
+
+	// CompactionL0Trigger defines the number of 'sorted tables' at level-0 that
+	// will trigger compaction.
+	//
+	// The default value is 4.
+	CompactionL0Trigger int
+
+	// CompactionSourceLimitFactor limits the compaction source size. This
+	// doesn't apply to level-0.
+	// This will be multiplied by the table size limit at the compaction target
+	// level.
+	//
+	// The default value is 1.
+	CompactionSourceLimitFactor int
+
+	// CompactionTableSize limits the size of the 'sorted tables' that compaction
+	// generates. The limit for each level will be calculated as:
+	//   CompactionTableSize * (CompactionTableSizeMultiplier ^ Level)
+	// The multiplier for each level can also be fine-tuned using
+	// CompactionTableSizeMultiplierPerLevel.
+	//
+	// The default value is 2MiB.
+	CompactionTableSize int
+
+	// CompactionTableSizeMultiplier defines the multiplier for CompactionTableSize.
+	//
+	// The default value is 1.
+	CompactionTableSizeMultiplier float64
+
+	// CompactionTableSizeMultiplierPerLevel defines a per-level multiplier for
+	// CompactionTableSize.
+	// Use zero to skip a level.
+	//
+	// The default value is nil.
+	CompactionTableSizeMultiplierPerLevel []float64
+
+	// CompactionTotalSize limits the total size of 'sorted tables' for each level.
+	// The limit for each level will be calculated as:
+	//   CompactionTotalSize * (CompactionTotalSizeMultiplier ^ Level)
+	// The multiplier for each level can also be fine-tuned using
+	// CompactionTotalSizeMultiplierPerLevel.
+	//
+	// The default value is 10MiB.
+	CompactionTotalSize int
+
+	// CompactionTotalSizeMultiplier defines the multiplier for CompactionTotalSize.
+	//
+	// The default value is 10.
+	CompactionTotalSizeMultiplier float64
+
+	// CompactionTotalSizeMultiplierPerLevel defines a per-level multiplier for
+	// CompactionTotalSize.
+	// Use zero to skip a level.
+	//
+	// The default value is nil.
+	CompactionTotalSizeMultiplierPerLevel []float64
+
+	// Comparer defines a total ordering over the space of []byte keys: a 'less
+	// than' relationship. The same comparison algorithm must be used for reads
+	// and writes over the lifetime of the DB.
+	//
+	// The default value uses the same ordering as bytes.Compare.
+	Comparer comparer.Comparer
+
+	// Compression defines the 'sorted table' block compression to use.
+	//
+	// The default value (DefaultCompression) uses snappy compression.
+	Compression Compression
+
+	// DisableBufferPool disables use of util.BufferPool functionality.
+	//
+	// The default value is false.
+	DisableBufferPool bool
+
+	// DisableBlockCache disables use of cache.Cache functionality on
+	// 'sorted table' blocks.
+	//
+	// The default value is false.
+	DisableBlockCache bool
+
+	// DisableCompactionBackoff disables compaction retry backoff.
+	//
+	// The default value is false.
+	DisableCompactionBackoff bool
+
+	// DisableLargeBatchTransaction allows disabling the switch-to-transaction
+	// mode on large batch writes. If the mode is enabled, batch writes larger
+	// than WriteBuffer will use a transaction.
+	//
+	// The default is false.
+	DisableLargeBatchTransaction bool
+
+	// DisableSeeksCompaction allows disabling 'seeks triggered compaction'.
+	// The purpose of 'seeks triggered compaction' is to optimize the database
+	// so that 'level seeks' can be minimized; however, this might generate many
+	// small compactions, which may not be preferable.
+	//
+	// The default is false.
+	DisableSeeksCompaction bool
+
+	// ErrorIfExist defines whether an error should be returned if the DB
+	// already exists.
+	//
+	// The default value is false.
+	ErrorIfExist bool
+
+	// ErrorIfMissing defines whether an error should be returned if the DB is
+	// missing. If false then the database will be created if missing, otherwise
+	// an error will be returned.
+	//
+	// The default value is false.
+	ErrorIfMissing bool
+
+	// Filter defines an 'effective filter' to use. An 'effective filter', if
+	// defined, will be used to generate the per-table filter block.
+	// The filter name will be stored on disk.
+	// During reads LevelDB will try to find a matching filter among the
+	// 'effective filter' and the 'alternative filters'.
+	//
+	// Filter can be changed after a DB has been created. It is recommended
+	// to put the old filter into the 'alternative filters' to mitigate the
+	// lack of a filter during the transition period.
+	//
+	// A filter is used to reduce disk reads when looking for a specific key.
+	//
+	// The default value is nil.
+	Filter filter.Filter
+
+	// IteratorSamplingRate defines the approximate gap (in bytes) between read
+	// sampling of an iterator. The samples will be used to determine when
+	// compaction should be triggered.
+	// Use a negative value to disable iterator sampling.
+	// The iterator sampling is disabled if DisableSeeksCompaction is true.
+	//
+	// The default is 1MiB.
+	IteratorSamplingRate int
+
+	// NoSync completely disables fsync.
+	//
+	// The default is false.
+	NoSync bool
+
+	// NoWriteMerge allows disabling write merge.
+	//
+	// The default is false.
+	NoWriteMerge bool
+
+	// OpenFilesCacher provides the cache algorithm for open files caching.
+	// Specify NoCacher to disable the caching algorithm.
+	//
+	// The default value is LRUCacher.
+	OpenFilesCacher Cacher
+
+	// OpenFilesCacheCapacity defines the capacity of the open files caching.
+	// Use -1 for zero; this has the same effect as specifying NoCacher for
+	// OpenFilesCacher.
+	//
+	// The default value is 500.
+	OpenFilesCacheCapacity int
+
+	// If true then the DB is opened in read-only mode.
+	//
+	// The default value is false.
+	ReadOnly bool
+
+	// Strict defines the DB strict level.
+	Strict Strict
+
+	// WriteBuffer defines the maximum size of a 'memdb' before it is flushed
+	// to a 'sorted table'. 'memdb' is an in-memory DB backed by an on-disk
+	// unsorted journal.
+	//
+	// LevelDB may hold up to two 'memdb' instances at the same time.
+	//
+	// The default value is 4MiB.
+	WriteBuffer int
+
+	// WriteL0PauseTrigger defines the number of 'sorted tables' at level-0 that
+	// will pause writes.
+	//
+	// The default value is 12.
+	WriteL0PauseTrigger int
+
+	// WriteL0SlowdownTrigger defines the number of 'sorted tables' at level-0
+	// that will trigger a write slowdown.
+	//
+	// The default value is 8.
+ WriteL0SlowdownTrigger int +} + +func (o *Options) GetAltFilters() []filter.Filter { + if o == nil { + return nil + } + return o.AltFilters +} + +func (o *Options) GetBlockCacher() Cacher { + if o == nil || o.BlockCacher == nil { + return DefaultBlockCacher + } else if o.BlockCacher == NoCacher { + return nil + } + return o.BlockCacher +} + +func (o *Options) GetBlockCacheCapacity() int { + if o == nil || o.BlockCacheCapacity == 0 { + return DefaultBlockCacheCapacity + } else if o.BlockCacheCapacity < 0 { + return 0 + } + return o.BlockCacheCapacity +} + +func (o *Options) GetBlockCacheEvictRemoved() bool { + if o == nil { + return false + } + return o.BlockCacheEvictRemoved +} + +func (o *Options) GetBlockRestartInterval() int { + if o == nil || o.BlockRestartInterval <= 0 { + return DefaultBlockRestartInterval + } + return o.BlockRestartInterval +} + +func (o *Options) GetBlockSize() int { + if o == nil || o.BlockSize <= 0 { + return DefaultBlockSize + } + return o.BlockSize +} + +func (o *Options) GetCompactionExpandLimit(level int) int { + factor := DefaultCompactionExpandLimitFactor + if o != nil && o.CompactionExpandLimitFactor > 0 { + factor = o.CompactionExpandLimitFactor + } + return o.GetCompactionTableSize(level+1) * factor +} + +func (o *Options) GetCompactionGPOverlaps(level int) int { + factor := DefaultCompactionGPOverlapsFactor + if o != nil && o.CompactionGPOverlapsFactor > 0 { + factor = o.CompactionGPOverlapsFactor + } + return o.GetCompactionTableSize(level+2) * factor +} + +func (o *Options) GetCompactionL0Trigger() int { + if o == nil || o.CompactionL0Trigger == 0 { + return DefaultCompactionL0Trigger + } + return o.CompactionL0Trigger +} + +func (o *Options) GetCompactionSourceLimit(level int) int { + factor := DefaultCompactionSourceLimitFactor + if o != nil && o.CompactionSourceLimitFactor > 0 { + factor = o.CompactionSourceLimitFactor + } + return o.GetCompactionTableSize(level+1) * factor +} + +func (o *Options) GetCompactionTableSize(level int) int { + var ( + base = DefaultCompactionTableSize + mult float64 + ) + if o != nil { + if o.CompactionTableSize > 0 { + base = o.CompactionTableSize + } + if level < len(o.CompactionTableSizeMultiplierPerLevel) && o.CompactionTableSizeMultiplierPerLevel[level] > 0 { + mult = o.CompactionTableSizeMultiplierPerLevel[level] + } else if o.CompactionTableSizeMultiplier > 0 { + mult = math.Pow(o.CompactionTableSizeMultiplier, float64(level)) + } + } + if mult == 0 { + mult = math.Pow(DefaultCompactionTableSizeMultiplier, float64(level)) + } + return int(float64(base) * mult) +} + +func (o *Options) GetCompactionTotalSize(level int) int64 { + var ( + base = DefaultCompactionTotalSize + mult float64 + ) + if o != nil { + if o.CompactionTotalSize > 0 { + base = o.CompactionTotalSize + } + if level < len(o.CompactionTotalSizeMultiplierPerLevel) && o.CompactionTotalSizeMultiplierPerLevel[level] > 0 { + mult = o.CompactionTotalSizeMultiplierPerLevel[level] + } else if o.CompactionTotalSizeMultiplier > 0 { + mult = math.Pow(o.CompactionTotalSizeMultiplier, float64(level)) + } + } + if mult == 0 { + mult = math.Pow(DefaultCompactionTotalSizeMultiplier, float64(level)) + } + return int64(float64(base) * mult) +} + +func (o *Options) GetComparer() comparer.Comparer { + if o == nil || o.Comparer == nil { + return comparer.DefaultComparer + } + return o.Comparer +} + +func (o *Options) GetCompression() Compression { + if o == nil || o.Compression <= DefaultCompression || o.Compression >= nCompression { + return DefaultCompressionType + } 
+ return o.Compression +} + +func (o *Options) GetDisableBufferPool() bool { + if o == nil { + return false + } + return o.DisableBufferPool +} + +func (o *Options) GetDisableBlockCache() bool { + if o == nil { + return false + } + return o.DisableBlockCache +} + +func (o *Options) GetDisableCompactionBackoff() bool { + if o == nil { + return false + } + return o.DisableCompactionBackoff +} + +func (o *Options) GetDisableLargeBatchTransaction() bool { + if o == nil { + return false + } + return o.DisableLargeBatchTransaction +} + +func (o *Options) GetDisableSeeksCompaction() bool { + if o == nil { + return false + } + return o.DisableSeeksCompaction +} + +func (o *Options) GetErrorIfExist() bool { + if o == nil { + return false + } + return o.ErrorIfExist +} + +func (o *Options) GetErrorIfMissing() bool { + if o == nil { + return false + } + return o.ErrorIfMissing +} + +func (o *Options) GetFilter() filter.Filter { + if o == nil { + return nil + } + return o.Filter +} + +func (o *Options) GetIteratorSamplingRate() int { + if o == nil || o.IteratorSamplingRate == 0 { + return DefaultIteratorSamplingRate + } else if o.IteratorSamplingRate < 0 { + return 0 + } + return o.IteratorSamplingRate +} + +func (o *Options) GetNoSync() bool { + if o == nil { + return false + } + return o.NoSync +} + +func (o *Options) GetNoWriteMerge() bool { + if o == nil { + return false + } + return o.NoWriteMerge +} + +func (o *Options) GetOpenFilesCacher() Cacher { + if o == nil || o.OpenFilesCacher == nil { + return DefaultOpenFilesCacher + } + if o.OpenFilesCacher == NoCacher { + return nil + } + return o.OpenFilesCacher +} + +func (o *Options) GetOpenFilesCacheCapacity() int { + if o == nil || o.OpenFilesCacheCapacity == 0 { + return DefaultOpenFilesCacheCapacity + } else if o.OpenFilesCacheCapacity < 0 { + return 0 + } + return o.OpenFilesCacheCapacity +} + +func (o *Options) GetReadOnly() bool { + if o == nil { + return false + } + return o.ReadOnly +} + +func (o *Options) GetStrict(strict Strict) bool { + if o == nil || o.Strict == 0 { + return DefaultStrict&strict != 0 + } + return o.Strict&strict != 0 +} + +func (o *Options) GetWriteBuffer() int { + if o == nil || o.WriteBuffer <= 0 { + return DefaultWriteBuffer + } + return o.WriteBuffer +} + +func (o *Options) GetWriteL0PauseTrigger() int { + if o == nil || o.WriteL0PauseTrigger == 0 { + return DefaultWriteL0PauseTrigger + } + return o.WriteL0PauseTrigger +} + +func (o *Options) GetWriteL0SlowdownTrigger() int { + if o == nil || o.WriteL0SlowdownTrigger == 0 { + return DefaultWriteL0SlowdownTrigger + } + return o.WriteL0SlowdownTrigger +} + +// ReadOptions holds the optional parameters for 'read operation'. The +// 'read operation' includes Get, Find and NewIterator. +type ReadOptions struct { + // DontFillCache defines whether block reads for this 'read operation' + // should be cached. If false then the block will be cached. This does + // not affects already cached block. + // + // The default value is false. + DontFillCache bool + + // Strict will be OR'ed with global DB 'strict level' unless StrictOverride + // is present. Currently only StrictReader that has effect here. + Strict Strict +} + +func (ro *ReadOptions) GetDontFillCache() bool { + if ro == nil { + return false + } + return ro.DontFillCache +} + +func (ro *ReadOptions) GetStrict(strict Strict) bool { + if ro == nil { + return false + } + return ro.Strict&strict != 0 +} + +// WriteOptions holds the optional parameters for 'write operation'. 
The +// 'write operation' includes Write, Put and Delete. +type WriteOptions struct { + // NoWriteMerge allows disabling write merge. + // + // The default is false. + NoWriteMerge bool + + // Sync is whether to sync underlying writes from the OS buffer cache + // through to actual disk, if applicable. Setting Sync can result in + // slower writes. + // + // If false, and the machine crashes, then some recent writes may be lost. + // Note that if it is just the process that crashes (and the machine does + // not) then no writes will be lost. + // + // In other words, Sync being false has the same semantics as a write + // system call. Sync being true means write followed by fsync. + // + // The default value is false. + Sync bool +} + +func (wo *WriteOptions) GetNoWriteMerge() bool { + if wo == nil { + return false + } + return wo.NoWriteMerge +} + +func (wo *WriteOptions) GetSync() bool { + if wo == nil { + return false + } + return wo.Sync +} + +func GetStrict(o *Options, ro *ReadOptions, strict Strict) bool { + if ro.GetStrict(StrictOverride) { + return ro.GetStrict(strict) + } else { + return o.GetStrict(strict) || ro.GetStrict(strict) + } +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/options.go b/vendor/github.com/syndtr/goleveldb/leveldb/options.go new file mode 100644 index 0000000..b072b1a --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/options.go @@ -0,0 +1,107 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "github.com/syndtr/goleveldb/leveldb/filter" + "github.com/syndtr/goleveldb/leveldb/opt" +) + +func dupOptions(o *opt.Options) *opt.Options { + newo := &opt.Options{} + if o != nil { + *newo = *o + } + if newo.Strict == 0 { + newo.Strict = opt.DefaultStrict + } + return newo +} + +func (s *session) setOptions(o *opt.Options) { + no := dupOptions(o) + // Alternative filters. + if filters := o.GetAltFilters(); len(filters) > 0 { + no.AltFilters = make([]filter.Filter, len(filters)) + for i, filter := range filters { + no.AltFilters[i] = &iFilter{filter} + } + } + // Comparer. + s.icmp = &iComparer{o.GetComparer()} + no.Comparer = s.icmp + // Filter. 
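The per-level limits used during compaction follow base * multiplier^level, which is what GetCompactionTableSize and GetCompactionTotalSize above compute, and what cachedOptions below precomputes for the first few levels to avoid repeated math.Pow calls. A quick illustration with default options:

```go
package main

import (
	"fmt"

	"github.com/syndtr/goleveldb/leveldb/opt"
)

func main() {
	var o *opt.Options // nil is fine; the accessors fall back to the defaults
	for level := 0; level < 4; level++ {
		// Defaults: table = 2 MiB * 1.0^level, total = 10 MiB * 10.0^level.
		fmt.Printf("L%d table=%d total=%d\n", level,
			o.GetCompactionTableSize(level),
			o.GetCompactionTotalSize(level))
	}
}
```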
+ if filter := o.GetFilter(); filter != nil { + no.Filter = &iFilter{filter} + } + + s.o = &cachedOptions{Options: no} + s.o.cache() +} + +const optCachedLevel = 7 + +type cachedOptions struct { + *opt.Options + + compactionExpandLimit []int + compactionGPOverlaps []int + compactionSourceLimit []int + compactionTableSize []int + compactionTotalSize []int64 +} + +func (co *cachedOptions) cache() { + co.compactionExpandLimit = make([]int, optCachedLevel) + co.compactionGPOverlaps = make([]int, optCachedLevel) + co.compactionSourceLimit = make([]int, optCachedLevel) + co.compactionTableSize = make([]int, optCachedLevel) + co.compactionTotalSize = make([]int64, optCachedLevel) + + for level := 0; level < optCachedLevel; level++ { + co.compactionExpandLimit[level] = co.Options.GetCompactionExpandLimit(level) + co.compactionGPOverlaps[level] = co.Options.GetCompactionGPOverlaps(level) + co.compactionSourceLimit[level] = co.Options.GetCompactionSourceLimit(level) + co.compactionTableSize[level] = co.Options.GetCompactionTableSize(level) + co.compactionTotalSize[level] = co.Options.GetCompactionTotalSize(level) + } +} + +func (co *cachedOptions) GetCompactionExpandLimit(level int) int { + if level < optCachedLevel { + return co.compactionExpandLimit[level] + } + return co.Options.GetCompactionExpandLimit(level) +} + +func (co *cachedOptions) GetCompactionGPOverlaps(level int) int { + if level < optCachedLevel { + return co.compactionGPOverlaps[level] + } + return co.Options.GetCompactionGPOverlaps(level) +} + +func (co *cachedOptions) GetCompactionSourceLimit(level int) int { + if level < optCachedLevel { + return co.compactionSourceLimit[level] + } + return co.Options.GetCompactionSourceLimit(level) +} + +func (co *cachedOptions) GetCompactionTableSize(level int) int { + if level < optCachedLevel { + return co.compactionTableSize[level] + } + return co.Options.GetCompactionTableSize(level) +} + +func (co *cachedOptions) GetCompactionTotalSize(level int) int64 { + if level < optCachedLevel { + return co.compactionTotalSize[level] + } + return co.Options.GetCompactionTotalSize(level) +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/session.go b/vendor/github.com/syndtr/goleveldb/leveldb/session.go new file mode 100644 index 0000000..7310209 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/session.go @@ -0,0 +1,239 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "fmt" + "io" + "os" + "sync" + + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/journal" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" +) + +// ErrManifestCorrupted records manifest corruption. This error will be +// wrapped with errors.ErrCorrupted. +type ErrManifestCorrupted struct { + Field string + Reason string +} + +func (e *ErrManifestCorrupted) Error() string { + return fmt.Sprintf("leveldb: manifest corrupted (field '%s'): %s", e.Field, e.Reason) +} + +func newErrManifestCorrupted(fd storage.FileDesc, field, reason string) error { + return errors.NewErrCorrupted(fd, &ErrManifestCorrupted{field, reason}) +} + +// session represent a persistent database session. +type session struct { + // Need 64-bit alignment. 
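ErrManifestCorrupted above uses the same wrapping pattern as ErrInternalKeyCorrupted in key.go: a descriptive inner error wrapped with the offending file descriptor via errors.NewErrCorrupted, which keeps it detectable with errors.IsCorrupted (relied on by recover below). A minimal sketch with an illustrative inner error:

```go
package main

import (
	"fmt"

	"github.com/syndtr/goleveldb/leveldb/errors"
	"github.com/syndtr/goleveldb/leveldb/storage"
)

func main() {
	// Illustrative stand-in for the *ErrManifestCorrupted inner error.
	inner := fmt.Errorf("manifest corrupted (field 'comparer'): missing")
	err := errors.NewErrCorrupted(storage.FileDesc{Type: storage.TypeManifest, Num: 2}, inner)

	fmt.Println(err)
	fmt.Println(errors.IsCorrupted(err)) // true
}
```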
+ stNextFileNum int64 // current unused file number + stJournalNum int64 // current journal file number; need external synchronization + stPrevJournalNum int64 // prev journal file number; no longer used; for compatibility with older version of leveldb + stTempFileNum int64 + stSeqNum uint64 // last mem compacted seq; need external synchronization + + stor *iStorage + storLock storage.Locker + o *cachedOptions + icmp *iComparer + tops *tOps + + manifest *journal.Writer + manifestWriter storage.Writer + manifestFd storage.FileDesc + + stCompPtrs []internalKey // compaction pointers; need external synchronization + stVersion *version // current version + ntVersionId int64 // next version id to assign + refCh chan *vTask + relCh chan *vTask + deltaCh chan *vDelta + abandon chan int64 + closeC chan struct{} + closeW sync.WaitGroup + vmu sync.Mutex + + // Testing fields + fileRefCh chan chan map[int64]int // channel used to pass current reference stat +} + +// Creates new initialized session instance. +func newSession(stor storage.Storage, o *opt.Options) (s *session, err error) { + if stor == nil { + return nil, os.ErrInvalid + } + storLock, err := stor.Lock() + if err != nil { + return + } + s = &session{ + stor: newIStorage(stor), + storLock: storLock, + refCh: make(chan *vTask), + relCh: make(chan *vTask), + deltaCh: make(chan *vDelta), + abandon: make(chan int64), + fileRefCh: make(chan chan map[int64]int), + closeC: make(chan struct{}), + } + s.setOptions(o) + s.tops = newTableOps(s) + + s.closeW.Add(1) + go s.refLoop() + s.setVersion(nil, newVersion(s)) + s.log("log@legend F·NumFile S·FileSize N·Entry C·BadEntry B·BadBlock Ke·KeyError D·DroppedEntry L·Level Q·SeqNum T·TimeElapsed") + return +} + +// Close session. +func (s *session) close() { + s.tops.close() + if s.manifest != nil { + s.manifest.Close() + } + if s.manifestWriter != nil { + s.manifestWriter.Close() + } + s.manifest = nil + s.manifestWriter = nil + s.setVersion(nil, &version{s: s, closing: true, id: s.ntVersionId}) + + // Close all background goroutines + close(s.closeC) + s.closeW.Wait() +} + +// Release session lock. +func (s *session) release() { + s.storLock.Unlock() +} + +// Create a new database session; need external synchronization. +func (s *session) create() error { + // create manifest + return s.newManifest(nil, nil) +} + +// Recover a database session; need external synchronization. +func (s *session) recover() (err error) { + defer func() { + if os.IsNotExist(err) { + // Don't return os.ErrNotExist if the underlying storage contains + // other files that belong to LevelDB. So the DB won't get trashed. + if fds, _ := s.stor.List(storage.TypeAll); len(fds) > 0 { + err = &errors.ErrCorrupted{Fd: storage.FileDesc{Type: storage.TypeManifest}, Err: &errors.ErrMissingFiles{}} + } + } + }() + + fd, err := s.stor.GetMeta() + if err != nil { + return + } + + reader, err := s.stor.Open(fd) + if err != nil { + return + } + defer reader.Close() + + var ( + // Options. 
+ strict = s.o.GetStrict(opt.StrictManifest) + + jr = journal.NewReader(reader, dropper{s, fd}, strict, true) + rec = &sessionRecord{} + staging = s.stVersion.newStaging() + ) + for { + var r io.Reader + r, err = jr.Next() + if err != nil { + if err == io.EOF { + err = nil + break + } + return errors.SetFd(err, fd) + } + + err = rec.decode(r) + if err == nil { + // save compact pointers + for _, r := range rec.compPtrs { + s.setCompPtr(r.level, internalKey(r.ikey)) + } + // commit record to version staging + staging.commit(rec) + } else { + err = errors.SetFd(err, fd) + if strict || !errors.IsCorrupted(err) { + return + } + s.logf("manifest error: %v (skipped)", errors.SetFd(err, fd)) + } + rec.resetCompPtrs() + rec.resetAddedTables() + rec.resetDeletedTables() + } + + switch { + case !rec.has(recComparer): + return newErrManifestCorrupted(fd, "comparer", "missing") + case rec.comparer != s.icmp.uName(): + return newErrManifestCorrupted(fd, "comparer", fmt.Sprintf("mismatch: want '%s', got '%s'", s.icmp.uName(), rec.comparer)) + case !rec.has(recNextFileNum): + return newErrManifestCorrupted(fd, "next-file-num", "missing") + case !rec.has(recJournalNum): + return newErrManifestCorrupted(fd, "journal-file-num", "missing") + case !rec.has(recSeqNum): + return newErrManifestCorrupted(fd, "seq-num", "missing") + } + + s.manifestFd = fd + s.setVersion(rec, staging.finish(false)) + s.setNextFileNum(rec.nextFileNum) + s.recordCommited(rec) + return nil +} + +// Commit session; need external synchronization. +func (s *session) commit(r *sessionRecord, trivial bool) (err error) { + v := s.version() + defer v.release() + + // spawn new version based on current version + nv := v.spawn(r, trivial) + + // abandon useless version id to prevent blocking version processing loop. + defer func() { + if err != nil { + s.abandon <- nv.id + s.logf("commit@abandon useless vid D%d", nv.id) + } + }() + + if s.manifest == nil { + // manifest journal writer not yet created, create one + err = s.newManifest(r, nv) + } else { + err = s.flushManifest(r) + } + + // finally, apply new version if no error rise + if err == nil { + s.setVersion(r, nv) + } + + return +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/session_compaction.go b/vendor/github.com/syndtr/goleveldb/leveldb/session_compaction.go new file mode 100644 index 0000000..4c1d336 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/session_compaction.go @@ -0,0 +1,326 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "sync/atomic" + + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/memdb" + "github.com/syndtr/goleveldb/leveldb/opt" +) + +const ( + undefinedCompaction = iota + level0Compaction + nonLevel0Compaction + seekCompaction +) + +func (s *session) pickMemdbLevel(umin, umax []byte, maxLevel int) int { + v := s.version() + defer v.release() + return v.pickMemdbLevel(umin, umax, maxLevel) +} + +func (s *session) flushMemdb(rec *sessionRecord, mdb *memdb.DB, maxLevel int) (int, error) { + // Create sorted table. + iter := mdb.NewIterator(nil) + defer iter.Release() + t, n, err := s.tops.createFrom(iter) + if err != nil { + return 0, err + } + + // Pick level other than zero can cause compaction issue with large + // bulk insert and delete on strictly incrementing key-space. 
The + // problem is that the small deletion markers trapped at lower level, + // while key/value entries keep growing at higher level. Since the + // key-space is strictly incrementing it will not overlaps with + // higher level, thus maximum possible level is always picked, while + // overlapping deletion marker pushed into lower level. + // See: https://github.com/syndtr/goleveldb/issues/127. + flushLevel := s.pickMemdbLevel(t.imin.ukey(), t.imax.ukey(), maxLevel) + rec.addTableFile(flushLevel, t) + + s.logf("memdb@flush created L%d@%d N·%d S·%s %q:%q", flushLevel, t.fd.Num, n, shortenb(int(t.size)), t.imin, t.imax) + return flushLevel, nil +} + +// Pick a compaction based on current state; need external synchronization. +func (s *session) pickCompaction() *compaction { + v := s.version() + + var sourceLevel int + var t0 tFiles + var typ int + if v.cScore >= 1 { + sourceLevel = v.cLevel + cptr := s.getCompPtr(sourceLevel) + tables := v.levels[sourceLevel] + for _, t := range tables { + if cptr == nil || s.icmp.Compare(t.imax, cptr) > 0 { + t0 = append(t0, t) + break + } + } + if len(t0) == 0 { + t0 = append(t0, tables[0]) + } + if sourceLevel == 0 { + typ = level0Compaction + } else { + typ = nonLevel0Compaction + } + } else { + if p := atomic.LoadPointer(&v.cSeek); p != nil { + ts := (*tSet)(p) + sourceLevel = ts.level + t0 = append(t0, ts.table) + typ = seekCompaction + } else { + v.release() + return nil + } + } + + return newCompaction(s, v, sourceLevel, t0, typ) +} + +// Create compaction from given level and range; need external synchronization. +func (s *session) getCompactionRange(sourceLevel int, umin, umax []byte, noLimit bool) *compaction { + v := s.version() + + if sourceLevel >= len(v.levels) { + v.release() + return nil + } + + t0 := v.levels[sourceLevel].getOverlaps(nil, s.icmp, umin, umax, sourceLevel == 0) + if len(t0) == 0 { + v.release() + return nil + } + + // Avoid compacting too much in one shot in case the range is large. + // But we cannot do this for level-0 since level-0 files can overlap + // and we must not pick one file and drop another older file if the + // two files overlap. + if !noLimit && sourceLevel > 0 { + limit := int64(v.s.o.GetCompactionSourceLimit(sourceLevel)) + total := int64(0) + for i, t := range t0 { + total += t.size + if total >= limit { + s.logf("table@compaction limiting F·%d -> F·%d", len(t0), i+1) + t0 = t0[:i+1] + break + } + } + } + + typ := level0Compaction + if sourceLevel != 0 { + typ = nonLevel0Compaction + } + return newCompaction(s, v, sourceLevel, t0, typ) +} + +func newCompaction(s *session, v *version, sourceLevel int, t0 tFiles, typ int) *compaction { + c := &compaction{ + s: s, + v: v, + typ: typ, + sourceLevel: sourceLevel, + levels: [2]tFiles{t0, nil}, + maxGPOverlaps: int64(s.o.GetCompactionGPOverlaps(sourceLevel)), + tPtrs: make([]int, len(v.levels)), + } + c.expand() + c.save() + return c +} + +// compaction represent a compaction state. +type compaction struct { + s *session + v *version + + typ int + sourceLevel int + levels [2]tFiles + maxGPOverlaps int64 + + gp tFiles + gpi int + seenKey bool + gpOverlappedBytes int64 + imin, imax internalKey + tPtrs []int + released bool + + snapGPI int + snapSeenKey bool + snapGPOverlappedBytes int64 + snapTPtrs []int +} + +func (c *compaction) save() { + c.snapGPI = c.gpi + c.snapSeenKey = c.seenKey + c.snapGPOverlappedBytes = c.gpOverlappedBytes + c.snapTPtrs = append(c.snapTPtrs[:0], c.tPtrs...) 
+} + +func (c *compaction) restore() { + c.gpi = c.snapGPI + c.seenKey = c.snapSeenKey + c.gpOverlappedBytes = c.snapGPOverlappedBytes + c.tPtrs = append(c.tPtrs[:0], c.snapTPtrs...) +} + +func (c *compaction) release() { + if !c.released { + c.released = true + c.v.release() + } +} + +// Expand compacted tables; need external synchronization. +func (c *compaction) expand() { + limit := int64(c.s.o.GetCompactionExpandLimit(c.sourceLevel)) + vt0 := c.v.levels[c.sourceLevel] + vt1 := tFiles{} + if level := c.sourceLevel + 1; level < len(c.v.levels) { + vt1 = c.v.levels[level] + } + + t0, t1 := c.levels[0], c.levels[1] + imin, imax := t0.getRange(c.s.icmp) + + // For non-zero levels, the ukey can't hop across tables at all. + if c.sourceLevel == 0 { + // We expand t0 here just incase ukey hop across tables. + t0 = vt0.getOverlaps(t0, c.s.icmp, imin.ukey(), imax.ukey(), c.sourceLevel == 0) + if len(t0) != len(c.levels[0]) { + imin, imax = t0.getRange(c.s.icmp) + } + } + t1 = vt1.getOverlaps(t1, c.s.icmp, imin.ukey(), imax.ukey(), false) + // Get entire range covered by compaction. + amin, amax := append(t0, t1...).getRange(c.s.icmp) + + // See if we can grow the number of inputs in "sourceLevel" without + // changing the number of "sourceLevel+1" files we pick up. + if len(t1) > 0 { + exp0 := vt0.getOverlaps(nil, c.s.icmp, amin.ukey(), amax.ukey(), c.sourceLevel == 0) + if len(exp0) > len(t0) && t1.size()+exp0.size() < limit { + xmin, xmax := exp0.getRange(c.s.icmp) + exp1 := vt1.getOverlaps(nil, c.s.icmp, xmin.ukey(), xmax.ukey(), false) + if len(exp1) == len(t1) { + c.s.logf("table@compaction expanding L%d+L%d (F·%d S·%s)+(F·%d S·%s) -> (F·%d S·%s)+(F·%d S·%s)", + c.sourceLevel, c.sourceLevel+1, len(t0), shortenb(int(t0.size())), len(t1), shortenb(int(t1.size())), + len(exp0), shortenb(int(exp0.size())), len(exp1), shortenb(int(exp1.size()))) + imin, imax = xmin, xmax + t0, t1 = exp0, exp1 + amin, amax = append(t0, t1...).getRange(c.s.icmp) + } + } + } + + // Compute the set of grandparent files that overlap this compaction + // (parent == sourceLevel+1; grandparent == sourceLevel+2) + if level := c.sourceLevel + 2; level < len(c.v.levels) { + c.gp = c.v.levels[level].getOverlaps(c.gp, c.s.icmp, amin.ukey(), amax.ukey(), false) + } + + c.levels[0], c.levels[1] = t0, t1 + c.imin, c.imax = imin, imax +} + +// Check whether compaction is trivial. +func (c *compaction) trivial() bool { + return len(c.levels[0]) == 1 && len(c.levels[1]) == 0 && c.gp.size() <= c.maxGPOverlaps +} + +func (c *compaction) baseLevelForKey(ukey []byte) bool { + for level := c.sourceLevel + 2; level < len(c.v.levels); level++ { + tables := c.v.levels[level] + for c.tPtrs[level] < len(tables) { + t := tables[c.tPtrs[level]] + if c.s.icmp.uCompare(ukey, t.imax.ukey()) <= 0 { + // We've advanced far enough. + if c.s.icmp.uCompare(ukey, t.imin.ukey()) >= 0 { + // Key falls in this file's range, so definitely not base level. + return false + } + break + } + c.tPtrs[level]++ + } + } + return true +} + +func (c *compaction) shouldStopBefore(ikey internalKey) bool { + for ; c.gpi < len(c.gp); c.gpi++ { + gp := c.gp[c.gpi] + if c.s.icmp.Compare(ikey, gp.imax) <= 0 { + break + } + if c.seenKey { + c.gpOverlappedBytes += gp.size + } + } + c.seenKey = true + + if c.gpOverlappedBytes > c.maxGPOverlaps { + // Too much overlap for current output; start new output. + c.gpOverlappedBytes = 0 + return true + } + return false +} + +// Creates an iterator. +func (c *compaction) newIterator() iterator.Iterator { + // Creates iterator slice. 
+	icap := len(c.levels)
+	if c.sourceLevel == 0 {
+		// Special case for level-0.
+		icap = len(c.levels[0]) + 1
+	}
+	its := make([]iterator.Iterator, 0, icap)
+
+	// Options.
+	ro := &opt.ReadOptions{
+		DontFillCache: true,
+		Strict:        opt.StrictOverride,
+	}
+	strict := c.s.o.GetStrict(opt.StrictCompaction)
+	if strict {
+		ro.Strict |= opt.StrictReader
+	}
+
+	for i, tables := range c.levels {
+		if len(tables) == 0 {
+			continue
+		}
+
+		// Level-0 tables are not sorted and may overlap each other.
+		if c.sourceLevel+i == 0 {
+			for _, t := range tables {
+				its = append(its, c.s.tops.newIterator(t, nil, ro))
+			}
+		} else {
+			it := iterator.NewIndexedIterator(tables.newIndexIterator(c.s.tops, c.s.icmp, nil, ro), strict)
+			its = append(its, it)
+		}
+	}
+
+	return iterator.NewMergedIterator(its, c.s.icmp, strict)
+}
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/session_record.go b/vendor/github.com/syndtr/goleveldb/leveldb/session_record.go
new file mode 100644
index 0000000..854e1aa
--- /dev/null
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/session_record.go
@@ -0,0 +1,323 @@
+// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package leveldb
+
+import (
+	"bufio"
+	"encoding/binary"
+	"io"
+	"strings"
+
+	"github.com/syndtr/goleveldb/leveldb/errors"
+	"github.com/syndtr/goleveldb/leveldb/storage"
+)
+
+type byteReader interface {
+	io.Reader
+	io.ByteReader
+}
+
+// These numbers are written to disk and should not be changed.
+const (
+	recComparer    = 1
+	recJournalNum  = 2
+	recNextFileNum = 3
+	recSeqNum      = 4
+	recCompPtr     = 5
+	recDelTable    = 6
+	recAddTable    = 7
+	// 8 was used for large value refs
+	recPrevJournalNum = 9
+)
+
+type cpRecord struct {
+	level int
+	ikey  internalKey
+}
+
+type atRecord struct {
+	level int
+	num   int64
+	size  int64
+	imin  internalKey
+	imax  internalKey
+}
+
+type dtRecord struct {
+	level int
+	num   int64
+}
+
+type sessionRecord struct {
+	hasRec         int
+	comparer       string
+	journalNum     int64
+	prevJournalNum int64
+	nextFileNum    int64
+	seqNum         uint64
+	compPtrs       []cpRecord
+	addedTables    []atRecord
+	deletedTables  []dtRecord
+
+	scratch [binary.MaxVarintLen64]byte
+	err     error
+}
+
+func (p *sessionRecord) has(rec int) bool {
+	return p.hasRec&(1<<uint(rec)) != 0
+}
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/session_record_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/session_record_test.go
new file mode 100644
--- /dev/null
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/session_record_test.go
+// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package leveldb
+
+import (
+	"bytes"
+	"testing"
+)
+
+func decodeEncode(v *sessionRecord) (res bool, err error) {
+	b := new(bytes.Buffer)
+	err = v.encode(b)
+	if err != nil {
+		return
+	}
+	// decode drains b below, so keep a copy of the encoded bytes for the
+	// final comparison.
+	encoded := append([]byte(nil), b.Bytes()...)
+	v2 := &sessionRecord{}
+	err = v2.decode(b)
+	if err != nil {
+		return
+	}
+	b2 := new(bytes.Buffer)
+	err = v2.encode(b2)
+	if err != nil {
+		return
+	}
+	return bytes.Equal(encoded, b2.Bytes()), nil
+}
+
+func TestSessionRecord_EncodeDecode(t *testing.T) {
+	big := int64(1) << 50
+	v := &sessionRecord{}
+	i := int64(0)
+	test := func() {
+		res, err := decodeEncode(v)
+		if err != nil {
+			t.Fatalf("error when testing encode/decode sessionRecord: %v", err)
+		}
+		if !res {
+			t.Error("encode/decode test failed at iteration:", i)
+		}
+	}
+
+	for ; i < 4; i++ {
+		test()
+		v.addTable(3, big+300+i, big+400+i,
+			makeInternalKey(nil, []byte("foo"), uint64(big+500+1), keyTypeVal),
+			makeInternalKey(nil, []byte("zoo"), uint64(big+600+1), keyTypeDel))
+		v.delTable(4, big+700+i)
+		v.addCompPtr(int(i), makeInternalKey(nil, []byte("x"), uint64(big+900+1), keyTypeVal))
+	}
+
+	v.setComparer("foo")
+	v.setJournalNum(big + 100)
+	v.setPrevJournalNum(big + 99)
+	v.setNextFileNum(big + 200)
+	v.setSeqNum(uint64(big + 1000))
+	test()
+}
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/session_util.go b/vendor/github.com/syndtr/goleveldb/leveldb/session_util.go
new file mode 100644
index 0000000..fc56b63
--- /dev/null
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/session_util.go
@@ -0,0 +1,483 @@
+// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package leveldb
+
+import (
+	"fmt"
+	"sync/atomic"
+	"time"
+
+	"github.com/syndtr/goleveldb/leveldb/journal"
+	"github.com/syndtr/goleveldb/leveldb/storage"
+)
+
+// Logging.
+
+type dropper struct {
+	s  *session
+	fd storage.FileDesc
+}
+
+func (d dropper) Drop(err error) {
+	if e, ok := err.(*journal.ErrCorrupted); ok {
+		d.s.logf("journal@drop %s-%d S·%s %q", d.fd.Type, d.fd.Num, shortenb(e.Size), e.Reason)
+	} else {
+		d.s.logf("journal@drop %s-%d %q", d.fd.Type, d.fd.Num, err)
+	}
+}
+
+func (s *session) log(v ...interface{})                 { s.stor.Log(fmt.Sprint(v...)) }
+func (s *session) logf(format string, v ...interface{}) { s.stor.Log(fmt.Sprintf(format, v...)) }
+
+// File utils.
+
+func (s *session) newTemp() storage.FileDesc {
+	num := atomic.AddInt64(&s.stTempFileNum, 1) - 1
+	return storage.FileDesc{Type: storage.TypeTemp, Num: num}
+}
+
+// Session state.
+
+const (
+	// maxCachedNumber represents the maximum number of version tasks
+	// that can be cached in the ref loop.
+	maxCachedNumber = 256
+
+	// maxCachedTime represents the maximum time for the ref loop to cache
+	// a version task.
+	maxCachedTime = 5 * time.Minute
+)
+
+// vDelta indicates the change information between the next version
+// and the currently specified version.
+type vDelta struct {
+	vid     int64
+	added   []int64
+	deleted []int64
+}
+
+// vTask defines a version task for either reference or release.
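+// A reference task is sent on refCh when a version becomes referenced
+// (carrying the version's full table-file set) and a matching task is sent
+// on relCh when it is released; refLoop below pairs the two to maintain
+// per-file reference counts, removing a table file once its count drops
+// to zero.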
+type vTask struct { + vid int64 + files []tFiles + created time.Time +} + +func (s *session) refLoop() { + var ( + fileRef = make(map[int64]int) // Table file reference counter + ref = make(map[int64]*vTask) // Current referencing version store + deltas = make(map[int64]*vDelta) + referenced = make(map[int64]struct{}) + released = make(map[int64]*vDelta) // Released version that waiting for processing + abandoned = make(map[int64]struct{}) // Abandoned version id + next, last int64 + ) + // addFileRef adds file reference counter with specified file number and + // reference value + addFileRef := func(fnum int64, ref int) int { + ref += fileRef[fnum] + if ref > 0 { + fileRef[fnum] = ref + } else if ref == 0 { + delete(fileRef, fnum) + } else { + panic(fmt.Sprintf("negative ref: %v", fnum)) + } + return ref + } + // skipAbandoned skips useless abandoned version id. + skipAbandoned := func() bool { + if _, exist := abandoned[next]; exist { + delete(abandoned, next) + return true + } + return false + } + // applyDelta applies version change to current file reference. + applyDelta := func(d *vDelta) { + for _, t := range d.added { + addFileRef(t, 1) + } + for _, t := range d.deleted { + if addFileRef(t, -1) == 0 { + s.tops.remove(storage.FileDesc{Type: storage.TypeTable, Num: t}) + } + } + } + + timer := time.NewTimer(0) + <-timer.C // discard the initial tick + defer timer.Stop() + + // processTasks processes version tasks in strict order. + // + // If we want to use delta to reduce the cost of file references and dereferences, + // we must strictly follow the id of the version, otherwise some files that are + // being referenced will be deleted. + // + // In addition, some db operations (such as iterators) may cause a version to be + // referenced for a long time. In order to prevent such operations from blocking + // the entire processing queue, we will properly convert some of the version tasks + // into full file references and releases. + processTasks := func() { + timer.Reset(maxCachedTime) + // Make sure we don't cache too many version tasks. + for { + // Skip any abandoned version number to prevent blocking processing. + if skipAbandoned() { + next += 1 + continue + } + // Don't bother the version that has been released. + if _, exist := released[next]; exist { + break + } + // Ensure the specified version has been referenced. + if _, exist := ref[next]; !exist { + break + } + if last-next < maxCachedNumber && time.Since(ref[next].created) < maxCachedTime { + break + } + // Convert version task into full file references and releases mode. + // Reference version(i+1) first and wait version(i) to release. + // FileRef(i+1) = FileRef(i) + Delta(i) + for _, tt := range ref[next].files { + for _, t := range tt { + addFileRef(t.fd.Num, 1) + } + } + // Note, if some compactions take a long time, even more than 5 minutes, + // we may miss the corresponding delta information here. + // Fortunately it will not affect the correctness of the file reference, + // and we can apply the delta once we receive it. + if d := deltas[next]; d != nil { + applyDelta(d) + } + referenced[next] = struct{}{} + delete(ref, next) + delete(deltas, next) + next += 1 + } + + // Use delta information to process all released versions. 
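+		// Example: if version 7 referenced files {3, 4} and its delta to
+		// version 8 is {added: 5, deleted: 3}, then applying the delta while
+		// releasing 7 leaves exactly version 8's files {4, 5} referenced.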
+ for { + if skipAbandoned() { + next += 1 + continue + } + if d, exist := released[next]; exist { + if d != nil { + applyDelta(d) + } + delete(released, next) + next += 1 + continue + } + return + } + } + + for { + processTasks() + + select { + case t := <-s.refCh: + if _, exist := ref[t.vid]; exist { + panic("duplicate reference request") + } + ref[t.vid] = t + if t.vid > last { + last = t.vid + } + + case d := <-s.deltaCh: + if _, exist := ref[d.vid]; !exist { + if _, exist2 := referenced[d.vid]; !exist2 { + panic("invalid release request") + } + // The reference opt is already expired, apply + // delta here. + applyDelta(d) + continue + } + deltas[d.vid] = d + + case t := <-s.relCh: + if _, exist := referenced[t.vid]; exist { + for _, tt := range t.files { + for _, t := range tt { + if addFileRef(t.fd.Num, -1) == 0 { + s.tops.remove(t.fd) + } + } + } + delete(referenced, t.vid) + continue + } + if _, exist := ref[t.vid]; !exist { + panic("invalid release request") + } + released[t.vid] = deltas[t.vid] + delete(deltas, t.vid) + delete(ref, t.vid) + + case id := <-s.abandon: + if id >= next { + abandoned[id] = struct{}{} + } + + case <-timer.C: + + case r := <-s.fileRefCh: + ref := make(map[int64]int) + for f, c := range fileRef { + ref[f] = c + } + r <- ref + + case <-s.closeC: + s.closeW.Done() + return + } + } +} + +// Get current version. This will incr version ref, must call +// version.release (exactly once) after use. +func (s *session) version() *version { + s.vmu.Lock() + defer s.vmu.Unlock() + s.stVersion.incref() + return s.stVersion +} + +func (s *session) tLen(level int) int { + s.vmu.Lock() + defer s.vmu.Unlock() + return s.stVersion.tLen(level) +} + +// Set current version to v. +func (s *session) setVersion(r *sessionRecord, v *version) { + s.vmu.Lock() + defer s.vmu.Unlock() + // Hold by session. It is important to call this first before releasing + // current version, otherwise the still used files might get released. + v.incref() + if s.stVersion != nil { + if r != nil { + var ( + added = make([]int64, 0, len(r.addedTables)) + deleted = make([]int64, 0, len(r.deletedTables)) + ) + for _, t := range r.addedTables { + added = append(added, t.num) + } + for _, t := range r.deletedTables { + deleted = append(deleted, t.num) + } + select { + case s.deltaCh <- &vDelta{vid: s.stVersion.id, added: added, deleted: deleted}: + case <-v.s.closeC: + s.log("reference loop already exist") + } + } + // Release current version. + s.stVersion.releaseNB() + } + s.stVersion = v +} + +// Get current unused file number. +func (s *session) nextFileNum() int64 { + return atomic.LoadInt64(&s.stNextFileNum) +} + +// Set current unused file number to num. +func (s *session) setNextFileNum(num int64) { + atomic.StoreInt64(&s.stNextFileNum, num) +} + +// Mark file number as used. +func (s *session) markFileNum(num int64) { + nextFileNum := num + 1 + for { + old, x := atomic.LoadInt64(&s.stNextFileNum), nextFileNum + if old > x { + x = old + } + if atomic.CompareAndSwapInt64(&s.stNextFileNum, old, x) { + break + } + } +} + +// Allocate a file number. +func (s *session) allocFileNum() int64 { + return atomic.AddInt64(&s.stNextFileNum, 1) - 1 +} + +// Reuse given file number. +func (s *session) reuseFileNum(num int64) { + for { + old, x := atomic.LoadInt64(&s.stNextFileNum), num + if old != x+1 { + x = old + } + if atomic.CompareAndSwapInt64(&s.stNextFileNum, old, x) { + break + } + } +} + +// Set compaction ptr at given level; need external synchronization. 
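+// The pointer acts as a per-level cursor: pickCompaction starts the next
+// compaction at the first table whose upper bound (imax) lies past the
+// pointer, so successive compactions rotate round-robin through the
+// level's key space.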
+func (s *session) setCompPtr(level int, ik internalKey) { + if level >= len(s.stCompPtrs) { + newCompPtrs := make([]internalKey, level+1) + copy(newCompPtrs, s.stCompPtrs) + s.stCompPtrs = newCompPtrs + } + s.stCompPtrs[level] = append(internalKey{}, ik...) +} + +// Get compaction ptr at given level; need external synchronization. +func (s *session) getCompPtr(level int) internalKey { + if level >= len(s.stCompPtrs) { + return nil + } + return s.stCompPtrs[level] +} + +// Manifest related utils. + +// Fill given session record obj with current states; need external +// synchronization. +func (s *session) fillRecord(r *sessionRecord, snapshot bool) { + r.setNextFileNum(s.nextFileNum()) + + if snapshot { + if !r.has(recJournalNum) { + r.setJournalNum(s.stJournalNum) + } + + if !r.has(recSeqNum) { + r.setSeqNum(s.stSeqNum) + } + + for level, ik := range s.stCompPtrs { + if ik != nil { + r.addCompPtr(level, ik) + } + } + + r.setComparer(s.icmp.uName()) + } +} + +// Mark if record has been committed, this will update session state; +// need external synchronization. +func (s *session) recordCommited(rec *sessionRecord) { + if rec.has(recJournalNum) { + s.stJournalNum = rec.journalNum + } + + if rec.has(recPrevJournalNum) { + s.stPrevJournalNum = rec.prevJournalNum + } + + if rec.has(recSeqNum) { + s.stSeqNum = rec.seqNum + } + + for _, r := range rec.compPtrs { + s.setCompPtr(r.level, internalKey(r.ikey)) + } +} + +// Create a new manifest file; need external synchronization. +func (s *session) newManifest(rec *sessionRecord, v *version) (err error) { + fd := storage.FileDesc{Type: storage.TypeManifest, Num: s.allocFileNum()} + writer, err := s.stor.Create(fd) + if err != nil { + return + } + jw := journal.NewWriter(writer) + + if v == nil { + v = s.version() + defer v.release() + } + if rec == nil { + rec = &sessionRecord{} + } + s.fillRecord(rec, true) + v.fillRecord(rec) + + defer func() { + if err == nil { + s.recordCommited(rec) + if s.manifest != nil { + s.manifest.Close() + } + if s.manifestWriter != nil { + s.manifestWriter.Close() + } + if !s.manifestFd.Zero() { + s.stor.Remove(s.manifestFd) + } + s.manifestFd = fd + s.manifestWriter = writer + s.manifest = jw + } else { + writer.Close() + s.stor.Remove(fd) + s.reuseFileNum(fd.Num) + } + }() + + w, err := jw.Next() + if err != nil { + return + } + err = rec.encode(w) + if err != nil { + return + } + err = jw.Flush() + if err != nil { + return + } + err = s.stor.SetMeta(fd) + return +} + +// Flush record to disk. 
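+// Unlike newManifest above, which writes a full snapshot into a fresh
+// manifest file, this appends a single incremental record to the current
+// manifest journal; commit picks between the two based on whether a
+// manifest writer exists yet.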
+func (s *session) flushManifest(rec *sessionRecord) (err error) { + s.fillRecord(rec, false) + w, err := s.manifest.Next() + if err != nil { + return + } + err = rec.encode(w) + if err != nil { + return + } + err = s.manifest.Flush() + if err != nil { + return + } + if !s.o.GetNoSync() { + err = s.manifestWriter.Sync() + if err != nil { + return + } + } + s.recordCommited(rec) + return +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/storage.go b/vendor/github.com/syndtr/goleveldb/leveldb/storage.go new file mode 100644 index 0000000..d45fb5d --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/storage.go @@ -0,0 +1,63 @@ +package leveldb + +import ( + "github.com/syndtr/goleveldb/leveldb/storage" + "sync/atomic" +) + +type iStorage struct { + storage.Storage + read uint64 + write uint64 +} + +func (c *iStorage) Open(fd storage.FileDesc) (storage.Reader, error) { + r, err := c.Storage.Open(fd) + return &iStorageReader{r, c}, err +} + +func (c *iStorage) Create(fd storage.FileDesc) (storage.Writer, error) { + w, err := c.Storage.Create(fd) + return &iStorageWriter{w, c}, err +} + +func (c *iStorage) reads() uint64 { + return atomic.LoadUint64(&c.read) +} + +func (c *iStorage) writes() uint64 { + return atomic.LoadUint64(&c.write) +} + +// newIStorage returns the given storage wrapped by iStorage. +func newIStorage(s storage.Storage) *iStorage { + return &iStorage{s, 0, 0} +} + +type iStorageReader struct { + storage.Reader + c *iStorage +} + +func (r *iStorageReader) Read(p []byte) (n int, err error) { + n, err = r.Reader.Read(p) + atomic.AddUint64(&r.c.read, uint64(n)) + return n, err +} + +func (r *iStorageReader) ReadAt(p []byte, off int64) (n int, err error) { + n, err = r.Reader.ReadAt(p, off) + atomic.AddUint64(&r.c.read, uint64(n)) + return n, err +} + +type iStorageWriter struct { + storage.Writer + c *iStorage +} + +func (w *iStorageWriter) Write(p []byte) (n int, err error) { + n, err = w.Writer.Write(p) + atomic.AddUint64(&w.c.write, uint64(n)) + return n, err +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage.go b/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage.go new file mode 100644 index 0000000..9ba71fd --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage.go @@ -0,0 +1,671 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reservefs. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
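A minimal usage sketch of this file-backed storage, driven only through the exported API introduced in this patch (the path and file number are illustrative, and this program is not itself part of the patch):

package main

import (
	"log"

	"github.com/syndtr/goleveldb/leveldb/storage"
)

func main() {
	// OpenFile acquires the LOCK file; a second concurrent open fails.
	stor, err := storage.OpenFile("/tmp/demo-db", false)
	if err != nil {
		log.Fatal(err)
	}
	defer stor.Close()

	// Creating a table file with Num 1 produces "000001.ldb" on disk.
	fd := storage.FileDesc{Type: storage.TypeTable, Num: 1}
	w, err := stor.Create(fd)
	if err != nil {
		log.Fatal(err)
	}
	defer w.Close()
	if _, err := w.Write([]byte("table bytes")); err != nil {
		log.Fatal(err)
	}
	if err := w.Sync(); err != nil { // commit to stable storage
		log.Fatal(err)
	}
}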
+ +package storage + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "time" +) + +var ( + errFileOpen = errors.New("leveldb/storage: file still open") + errReadOnly = errors.New("leveldb/storage: storage is read-only") +) + +type fileLock interface { + release() error +} + +type fileStorageLock struct { + fs *fileStorage +} + +func (lock *fileStorageLock) Unlock() { + if lock.fs != nil { + lock.fs.mu.Lock() + defer lock.fs.mu.Unlock() + if lock.fs.slock == lock { + lock.fs.slock = nil + } + } +} + +type int64Slice []int64 + +func (p int64Slice) Len() int { return len(p) } +func (p int64Slice) Less(i, j int) bool { return p[i] < p[j] } +func (p int64Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +func writeFileSynced(filename string, data []byte, perm os.FileMode) error { + f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm) + if err != nil { + return err + } + n, err := f.Write(data) + if err == nil && n < len(data) { + err = io.ErrShortWrite + } + if err1 := f.Sync(); err == nil { + err = err1 + } + if err1 := f.Close(); err == nil { + err = err1 + } + return err +} + +const logSizeThreshold = 1024 * 1024 // 1 MiB + +// fileStorage is a file-system backed storage. +type fileStorage struct { + path string + readOnly bool + + mu sync.Mutex + flock fileLock + slock *fileStorageLock + logw *os.File + logSize int64 + buf []byte + // Opened file counter; if open < 0 means closed. + open int + day int +} + +// OpenFile returns a new filesystem-backed storage implementation with the given +// path. This also acquire a file lock, so any subsequent attempt to open the +// same path will fail. +// +// The storage must be closed after use, by calling Close method. +func OpenFile(path string, readOnly bool) (Storage, error) { + if fi, err := os.Stat(path); err == nil { + if !fi.IsDir() { + return nil, fmt.Errorf("leveldb/storage: open %s: not a directory", path) + } + } else if os.IsNotExist(err) && !readOnly { + if err := os.MkdirAll(path, 0755); err != nil { + return nil, err + } + } else { + return nil, err + } + + flock, err := newFileLock(filepath.Join(path, "LOCK"), readOnly) + if err != nil { + return nil, err + } + + defer func() { + if err != nil { + flock.release() + } + }() + + var ( + logw *os.File + logSize int64 + ) + if !readOnly { + logw, err = os.OpenFile(filepath.Join(path, "LOG"), os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + return nil, err + } + logSize, err = logw.Seek(0, os.SEEK_END) + if err != nil { + logw.Close() + return nil, err + } + } + + fs := &fileStorage{ + path: path, + readOnly: readOnly, + flock: flock, + logw: logw, + logSize: logSize, + } + runtime.SetFinalizer(fs, (*fileStorage).Close) + return fs, nil +} + +func (fs *fileStorage) Lock() (Locker, error) { + fs.mu.Lock() + defer fs.mu.Unlock() + if fs.open < 0 { + return nil, ErrClosed + } + if fs.readOnly { + return &fileStorageLock{}, nil + } + if fs.slock != nil { + return nil, ErrLocked + } + fs.slock = &fileStorageLock{fs: fs} + return fs.slock, nil +} + +func itoa(buf []byte, i int, wid int) []byte { + u := uint(i) + if u == 0 && wid <= 1 { + return append(buf, '0') + } + + // Assemble decimal in reverse order. + var b [32]byte + bp := len(b) + for ; u > 0 || wid > 0; u /= 10 { + bp-- + wid-- + b[bp] = byte(u%10) + '0' + } + return append(buf, b[bp:]...) 
+} + +func (fs *fileStorage) printDay(t time.Time) { + if fs.day == t.Day() { + return + } + fs.day = t.Day() + fs.logw.Write([]byte("=============== " + t.Format("Jan 2, 2006 (MST)") + " ===============\n")) +} + +func (fs *fileStorage) doLog(t time.Time, str string) { + if fs.logSize > logSizeThreshold { + // Rotate log file. + fs.logw.Close() + fs.logw = nil + fs.logSize = 0 + rename(filepath.Join(fs.path, "LOG"), filepath.Join(fs.path, "LOG.old")) + } + if fs.logw == nil { + var err error + fs.logw, err = os.OpenFile(filepath.Join(fs.path, "LOG"), os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + return + } + // Force printDay on new log file. + fs.day = 0 + } + fs.printDay(t) + hour, min, sec := t.Clock() + msec := t.Nanosecond() / 1e3 + // time + fs.buf = itoa(fs.buf[:0], hour, 2) + fs.buf = append(fs.buf, ':') + fs.buf = itoa(fs.buf, min, 2) + fs.buf = append(fs.buf, ':') + fs.buf = itoa(fs.buf, sec, 2) + fs.buf = append(fs.buf, '.') + fs.buf = itoa(fs.buf, msec, 6) + fs.buf = append(fs.buf, ' ') + // write + fs.buf = append(fs.buf, []byte(str)...) + fs.buf = append(fs.buf, '\n') + n, _ := fs.logw.Write(fs.buf) + fs.logSize += int64(n) +} + +func (fs *fileStorage) Log(str string) { + if !fs.readOnly { + t := time.Now() + fs.mu.Lock() + defer fs.mu.Unlock() + if fs.open < 0 { + return + } + fs.doLog(t, str) + } +} + +func (fs *fileStorage) log(str string) { + if !fs.readOnly { + fs.doLog(time.Now(), str) + } +} + +func (fs *fileStorage) setMeta(fd FileDesc) error { + content := fsGenName(fd) + "\n" + // Check and backup old CURRENT file. + currentPath := filepath.Join(fs.path, "CURRENT") + if _, err := os.Stat(currentPath); err == nil { + b, err := ioutil.ReadFile(currentPath) + if err != nil { + fs.log(fmt.Sprintf("backup CURRENT: %v", err)) + return err + } + if string(b) == content { + // Content not changed, do nothing. + return nil + } + if err := writeFileSynced(currentPath+".bak", b, 0644); err != nil { + fs.log(fmt.Sprintf("backup CURRENT: %v", err)) + return err + } + } else if !os.IsNotExist(err) { + return err + } + path := fmt.Sprintf("%s.%d", filepath.Join(fs.path, "CURRENT"), fd.Num) + if err := writeFileSynced(path, []byte(content), 0644); err != nil { + fs.log(fmt.Sprintf("create CURRENT.%d: %v", fd.Num, err)) + return err + } + // Replace CURRENT file. + if err := rename(path, currentPath); err != nil { + fs.log(fmt.Sprintf("rename CURRENT.%d: %v", fd.Num, err)) + return err + } + // Sync root directory. + if err := syncDir(fs.path); err != nil { + fs.log(fmt.Sprintf("syncDir: %v", err)) + return err + } + return nil +} + +func (fs *fileStorage) SetMeta(fd FileDesc) error { + if !FileDescOk(fd) { + return ErrInvalidFile + } + if fs.readOnly { + return errReadOnly + } + + fs.mu.Lock() + defer fs.mu.Unlock() + if fs.open < 0 { + return ErrClosed + } + return fs.setMeta(fd) +} + +func (fs *fileStorage) GetMeta() (FileDesc, error) { + fs.mu.Lock() + defer fs.mu.Unlock() + if fs.open < 0 { + return FileDesc{}, ErrClosed + } + dir, err := os.Open(fs.path) + if err != nil { + return FileDesc{}, err + } + names, err := dir.Readdirnames(0) + // Close the dir first before checking for Readdirnames error. + if ce := dir.Close(); ce != nil { + fs.log(fmt.Sprintf("close dir: %v", ce)) + } + if err != nil { + return FileDesc{}, err + } + // Try this in order: + // - CURRENT.[0-9]+ ('pending rename' file, descending order) + // - CURRENT + // - CURRENT.bak + // + // Skip corrupted file or file that point to a missing target file. 
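+	// For example, if CURRENT points at MANIFEST-000002 but CURRENT.000005
+	// (a pending rename) points at an existing MANIFEST-000005, the pending
+	// entry wins: CURRENT is rewritten via setMeta and the stale CURRENT.*
+	// files are removed.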
+ type currentFile struct { + name string + fd FileDesc + } + tryCurrent := func(name string) (*currentFile, error) { + b, err := ioutil.ReadFile(filepath.Join(fs.path, name)) + if err != nil { + if os.IsNotExist(err) { + err = os.ErrNotExist + } + return nil, err + } + var fd FileDesc + if len(b) < 1 || b[len(b)-1] != '\n' || !fsParseNamePtr(string(b[:len(b)-1]), &fd) { + fs.log(fmt.Sprintf("%s: corrupted content: %q", name, b)) + err := &ErrCorrupted{ + Err: errors.New("leveldb/storage: corrupted or incomplete CURRENT file"), + } + return nil, err + } + if _, err := os.Stat(filepath.Join(fs.path, fsGenName(fd))); err != nil { + if os.IsNotExist(err) { + fs.log(fmt.Sprintf("%s: missing target file: %s", name, fd)) + err = os.ErrNotExist + } + return nil, err + } + return ¤tFile{name: name, fd: fd}, nil + } + tryCurrents := func(names []string) (*currentFile, error) { + var ( + cur *currentFile + // Last corruption error. + lastCerr error + ) + for _, name := range names { + var err error + cur, err = tryCurrent(name) + if err == nil { + break + } else if err == os.ErrNotExist { + // Fallback to the next file. + } else if isCorrupted(err) { + lastCerr = err + // Fallback to the next file. + } else { + // In case the error is due to permission, etc. + return nil, err + } + } + if cur == nil { + err := os.ErrNotExist + if lastCerr != nil { + err = lastCerr + } + return nil, err + } + return cur, nil + } + + // Try 'pending rename' files. + var nums []int64 + for _, name := range names { + if strings.HasPrefix(name, "CURRENT.") && name != "CURRENT.bak" { + i, err := strconv.ParseInt(name[8:], 10, 64) + if err == nil { + nums = append(nums, i) + } + } + } + var ( + pendCur *currentFile + pendErr = os.ErrNotExist + pendNames []string + ) + if len(nums) > 0 { + sort.Sort(sort.Reverse(int64Slice(nums))) + pendNames = make([]string, len(nums)) + for i, num := range nums { + pendNames[i] = fmt.Sprintf("CURRENT.%d", num) + } + pendCur, pendErr = tryCurrents(pendNames) + if pendErr != nil && pendErr != os.ErrNotExist && !isCorrupted(pendErr) { + return FileDesc{}, pendErr + } + } + + // Try CURRENT and CURRENT.bak. + curCur, curErr := tryCurrents([]string{"CURRENT", "CURRENT.bak"}) + if curErr != nil && curErr != os.ErrNotExist && !isCorrupted(curErr) { + return FileDesc{}, curErr + } + + // pendCur takes precedence, but guards against obsolete pendCur. + if pendCur != nil && (curCur == nil || pendCur.fd.Num > curCur.fd.Num) { + curCur = pendCur + } + + if curCur != nil { + // Restore CURRENT file to proper state. + if !fs.readOnly && (curCur.name != "CURRENT" || len(pendNames) != 0) { + // Ignore setMeta errors, however don't delete obsolete files if we + // catch error. + if err := fs.setMeta(curCur.fd); err == nil { + // Remove 'pending rename' files. + for _, name := range pendNames { + if err := os.Remove(filepath.Join(fs.path, name)); err != nil { + fs.log(fmt.Sprintf("remove %s: %v", name, err)) + } + } + } + } + return curCur.fd, nil + } + + // Nothing found. + if isCorrupted(pendErr) { + return FileDesc{}, pendErr + } + return FileDesc{}, curErr +} + +func (fs *fileStorage) List(ft FileType) (fds []FileDesc, err error) { + fs.mu.Lock() + defer fs.mu.Unlock() + if fs.open < 0 { + return nil, ErrClosed + } + dir, err := os.Open(fs.path) + if err != nil { + return + } + names, err := dir.Readdirnames(0) + // Close the dir first before checking for Readdirnames error. 
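+	// (On error the partial listing is discarded; the error itself is only
+	// returned after the directory handle has been closed.)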
+ if cerr := dir.Close(); cerr != nil { + fs.log(fmt.Sprintf("close dir: %v", cerr)) + } + if err == nil { + for _, name := range names { + if fd, ok := fsParseName(name); ok && fd.Type&ft != 0 { + fds = append(fds, fd) + } + } + } + return +} + +func (fs *fileStorage) Open(fd FileDesc) (Reader, error) { + if !FileDescOk(fd) { + return nil, ErrInvalidFile + } + + fs.mu.Lock() + defer fs.mu.Unlock() + if fs.open < 0 { + return nil, ErrClosed + } + of, err := os.OpenFile(filepath.Join(fs.path, fsGenName(fd)), os.O_RDONLY, 0) + if err != nil { + if fsHasOldName(fd) && os.IsNotExist(err) { + of, err = os.OpenFile(filepath.Join(fs.path, fsGenOldName(fd)), os.O_RDONLY, 0) + if err == nil { + goto ok + } + } + return nil, err + } +ok: + fs.open++ + return &fileWrap{File: of, fs: fs, fd: fd}, nil +} + +func (fs *fileStorage) Create(fd FileDesc) (Writer, error) { + if !FileDescOk(fd) { + return nil, ErrInvalidFile + } + if fs.readOnly { + return nil, errReadOnly + } + + fs.mu.Lock() + defer fs.mu.Unlock() + if fs.open < 0 { + return nil, ErrClosed + } + of, err := os.OpenFile(filepath.Join(fs.path, fsGenName(fd)), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return nil, err + } + fs.open++ + return &fileWrap{File: of, fs: fs, fd: fd}, nil +} + +func (fs *fileStorage) Remove(fd FileDesc) error { + if !FileDescOk(fd) { + return ErrInvalidFile + } + if fs.readOnly { + return errReadOnly + } + + fs.mu.Lock() + defer fs.mu.Unlock() + if fs.open < 0 { + return ErrClosed + } + err := os.Remove(filepath.Join(fs.path, fsGenName(fd))) + if err != nil { + if fsHasOldName(fd) && os.IsNotExist(err) { + if e1 := os.Remove(filepath.Join(fs.path, fsGenOldName(fd))); !os.IsNotExist(e1) { + fs.log(fmt.Sprintf("remove %s: %v (old name)", fd, err)) + err = e1 + } + } else { + fs.log(fmt.Sprintf("remove %s: %v", fd, err)) + } + } + return err +} + +func (fs *fileStorage) Rename(oldfd, newfd FileDesc) error { + if !FileDescOk(oldfd) || !FileDescOk(newfd) { + return ErrInvalidFile + } + if oldfd == newfd { + return nil + } + if fs.readOnly { + return errReadOnly + } + + fs.mu.Lock() + defer fs.mu.Unlock() + if fs.open < 0 { + return ErrClosed + } + return rename(filepath.Join(fs.path, fsGenName(oldfd)), filepath.Join(fs.path, fsGenName(newfd))) +} + +func (fs *fileStorage) Close() error { + fs.mu.Lock() + defer fs.mu.Unlock() + if fs.open < 0 { + return ErrClosed + } + // Clear the finalizer. + runtime.SetFinalizer(fs, nil) + + if fs.open > 0 { + fs.log(fmt.Sprintf("close: warning, %d files still open", fs.open)) + } + fs.open = -1 + if fs.logw != nil { + fs.logw.Close() + } + return fs.flock.release() +} + +type fileWrap struct { + *os.File + fs *fileStorage + fd FileDesc + closed bool +} + +func (fw *fileWrap) Sync() error { + if err := fw.File.Sync(); err != nil { + return err + } + if fw.fd.Type == TypeManifest { + // Also sync parent directory if file type is manifest. + // See: https://code.google.com/p/leveldb/issues/detail?id=190. 
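+		// Without the directory sync, a crash shortly after a manifest
+		// switch could persist the new file's contents but not its
+		// directory entry, leaving CURRENT pointing at a manifest that is
+		// gone after restart.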
+ if err := syncDir(fw.fs.path); err != nil { + fw.fs.log(fmt.Sprintf("syncDir: %v", err)) + return err + } + } + return nil +} + +func (fw *fileWrap) Close() error { + fw.fs.mu.Lock() + defer fw.fs.mu.Unlock() + if fw.closed { + return ErrClosed + } + fw.closed = true + fw.fs.open-- + err := fw.File.Close() + if err != nil { + fw.fs.log(fmt.Sprintf("close %s: %v", fw.fd, err)) + } + return err +} + +func fsGenName(fd FileDesc) string { + switch fd.Type { + case TypeManifest: + return fmt.Sprintf("MANIFEST-%06d", fd.Num) + case TypeJournal: + return fmt.Sprintf("%06d.log", fd.Num) + case TypeTable: + return fmt.Sprintf("%06d.ldb", fd.Num) + case TypeTemp: + return fmt.Sprintf("%06d.tmp", fd.Num) + default: + panic("invalid file type") + } +} + +func fsHasOldName(fd FileDesc) bool { + return fd.Type == TypeTable +} + +func fsGenOldName(fd FileDesc) string { + switch fd.Type { + case TypeTable: + return fmt.Sprintf("%06d.sst", fd.Num) + } + return fsGenName(fd) +} + +func fsParseName(name string) (fd FileDesc, ok bool) { + var tail string + _, err := fmt.Sscanf(name, "%d.%s", &fd.Num, &tail) + if err == nil { + switch tail { + case "log": + fd.Type = TypeJournal + case "ldb", "sst": + fd.Type = TypeTable + case "tmp": + fd.Type = TypeTemp + default: + return + } + return fd, true + } + n, _ := fmt.Sscanf(name, "MANIFEST-%d%s", &fd.Num, &tail) + if n == 1 { + fd.Type = TypeManifest + return fd, true + } + return +} + +func fsParseNamePtr(name string, fd *FileDesc) bool { + _fd, ok := fsParseName(name) + if fd != nil { + *fd = _fd + } + return ok +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_nacl.go b/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_nacl.go new file mode 100644 index 0000000..5545aee --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_nacl.go @@ -0,0 +1,34 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// +build nacl + +package storage + +import ( + "os" + "syscall" +) + +func newFileLock(path string, readOnly bool) (fl fileLock, err error) { + return nil, syscall.ENOTSUP +} + +func setFileLock(f *os.File, readOnly, lock bool) error { + return syscall.ENOTSUP +} + +func rename(oldpath, newpath string) error { + return syscall.ENOTSUP +} + +func isErrInvalid(err error) bool { + return false +} + +func syncDir(name string) error { + return syscall.ENOTSUP +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_plan9.go b/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_plan9.go new file mode 100644 index 0000000..b829798 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_plan9.go @@ -0,0 +1,63 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +package storage + +import ( + "os" +) + +type plan9FileLock struct { + f *os.File +} + +func (fl *plan9FileLock) release() error { + return fl.f.Close() +} + +func newFileLock(path string, readOnly bool) (fl fileLock, err error) { + var ( + flag int + perm os.FileMode + ) + if readOnly { + flag = os.O_RDONLY + } else { + flag = os.O_RDWR + perm = os.ModeExclusive + } + f, err := os.OpenFile(path, flag, perm) + if os.IsNotExist(err) { + f, err = os.OpenFile(path, flag|os.O_CREATE, perm|0644) + } + if err != nil { + return + } + fl = &plan9FileLock{f: f} + return +} + +func rename(oldpath, newpath string) error { + if _, err := os.Stat(newpath); err == nil { + if err := os.Remove(newpath); err != nil { + return err + } + } + + return os.Rename(oldpath, newpath) +} + +func syncDir(name string) error { + f, err := os.Open(name) + if err != nil { + return err + } + defer f.Close() + if err := f.Sync(); err != nil { + return err + } + return nil +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_solaris.go b/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_solaris.go new file mode 100644 index 0000000..79901ee --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_solaris.go @@ -0,0 +1,81 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// +build solaris + +package storage + +import ( + "os" + "syscall" +) + +type unixFileLock struct { + f *os.File +} + +func (fl *unixFileLock) release() error { + if err := setFileLock(fl.f, false, false); err != nil { + return err + } + return fl.f.Close() +} + +func newFileLock(path string, readOnly bool) (fl fileLock, err error) { + var flag int + if readOnly { + flag = os.O_RDONLY + } else { + flag = os.O_RDWR + } + f, err := os.OpenFile(path, flag, 0) + if os.IsNotExist(err) { + f, err = os.OpenFile(path, flag|os.O_CREATE, 0644) + } + if err != nil { + return + } + err = setFileLock(f, readOnly, true) + if err != nil { + f.Close() + return + } + fl = &unixFileLock{f: f} + return +} + +func setFileLock(f *os.File, readOnly, lock bool) error { + flock := syscall.Flock_t{ + Type: syscall.F_UNLCK, + Start: 0, + Len: 0, + Whence: 1, + } + if lock { + if readOnly { + flock.Type = syscall.F_RDLCK + } else { + flock.Type = syscall.F_WRLCK + } + } + return syscall.FcntlFlock(f.Fd(), syscall.F_SETLK, &flock) +} + +func rename(oldpath, newpath string) error { + return os.Rename(oldpath, newpath) +} + +func syncDir(name string) error { + f, err := os.Open(name) + if err != nil { + return err + } + defer f.Close() + if err := f.Sync(); err != nil { + return err + } + return nil +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_test.go new file mode 100644 index 0000000..2e60315 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_test.go @@ -0,0 +1,400 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +package storage + +import ( + "fmt" + "io/ioutil" + "math/rand" + "os" + "path/filepath" + "strings" + "testing" +) + +var cases = []struct { + oldName []string + name string + ftype FileType + num int64 +}{ + {nil, "000100.log", TypeJournal, 100}, + {nil, "000000.log", TypeJournal, 0}, + {[]string{"000000.sst"}, "000000.ldb", TypeTable, 0}, + {nil, "MANIFEST-000002", TypeManifest, 2}, + {nil, "MANIFEST-000007", TypeManifest, 7}, + {nil, "9223372036854775807.log", TypeJournal, 9223372036854775807}, + {nil, "000100.tmp", TypeTemp, 100}, +} + +var invalidCases = []string{ + "", + "foo", + "foo-dx-100.log", + ".log", + "", + "manifest", + "CURREN", + "CURRENTX", + "MANIFES", + "MANIFEST", + "MANIFEST-", + "XMANIFEST-3", + "MANIFEST-3x", + "LOC", + "LOCKx", + "LO", + "LOGx", + "18446744073709551616.log", + "184467440737095516150.log", + "100", + "100.", + "100.lop", +} + +func tempDir(t *testing.T) string { + dir, err := ioutil.TempDir("", "goleveldb-") + if err != nil { + t.Fatal(t) + } + t.Log("Using temp-dir:", dir) + return dir +} + +func TestFileStorage_CreateFileName(t *testing.T) { + for _, c := range cases { + if name := fsGenName(FileDesc{c.ftype, c.num}); name != c.name { + t.Errorf("invalid filename got '%s', want '%s'", name, c.name) + } + } +} + +func TestFileStorage_MetaSetGet(t *testing.T) { + temp := tempDir(t) + fs, err := OpenFile(temp, false) + if err != nil { + t.Fatal("OpenFile: got error: ", err) + } + + for i := 0; i < 10; i++ { + num := rand.Int63() + fd := FileDesc{Type: TypeManifest, Num: num} + w, err := fs.Create(fd) + if err != nil { + t.Fatalf("Create(%d): got error: %v", i, err) + } + w.Write([]byte("TEST")) + w.Close() + if err := fs.SetMeta(fd); err != nil { + t.Fatalf("SetMeta(%d): got error: %v", i, err) + } + rfd, err := fs.GetMeta() + if err != nil { + t.Fatalf("GetMeta(%d): got error: %v", i, err) + } + if fd != rfd { + t.Fatalf("Invalid meta (%d): got '%s', want '%s'", i, rfd, fd) + } + } + os.RemoveAll(temp) +} + +func TestFileStorage_Meta(t *testing.T) { + type current struct { + num int64 + backup bool + current bool + manifest bool + corrupt bool + } + type testCase struct { + currents []current + notExist bool + corrupt bool + expect int64 + } + cases := []testCase{ + { + currents: []current{ + {num: 2, backup: true, manifest: true}, + {num: 1, current: true}, + }, + expect: 2, + }, + { + currents: []current{ + {num: 2, backup: true, manifest: true}, + {num: 1, current: true, manifest: true}, + }, + expect: 1, + }, + { + currents: []current{ + {num: 2, manifest: true}, + {num: 3, manifest: true}, + {num: 4, current: true, manifest: true}, + }, + expect: 4, + }, + { + currents: []current{ + {num: 2, manifest: true}, + {num: 3, manifest: true}, + {num: 4, current: true, manifest: true, corrupt: true}, + }, + expect: 3, + }, + { + currents: []current{ + {num: 2, manifest: true}, + {num: 3, manifest: true}, + {num: 5, current: true, manifest: true, corrupt: true}, + {num: 4, backup: true, manifest: true}, + }, + expect: 4, + }, + { + currents: []current{ + {num: 4, manifest: true}, + {num: 3, manifest: true}, + {num: 2, current: true, manifest: true}, + }, + expect: 4, + }, + { + currents: []current{ + {num: 4, manifest: true, corrupt: true}, + {num: 3, manifest: true}, + {num: 2, current: true, manifest: true}, + }, + expect: 3, + }, + { + currents: []current{ + {num: 4, manifest: true, corrupt: true}, + {num: 3, manifest: true, corrupt: true}, + {num: 2, current: true, manifest: true}, + }, + expect: 2, + }, + { + currents: []current{ + {num: 4}, + 
{num: 3, manifest: true}, + {num: 2, current: true, manifest: true}, + }, + expect: 3, + }, + { + currents: []current{ + {num: 4}, + {num: 3, manifest: true}, + {num: 6, current: true}, + {num: 5, backup: true, manifest: true}, + }, + expect: 5, + }, + { + currents: []current{ + {num: 4}, + {num: 3}, + {num: 6, current: true}, + {num: 5, backup: true}, + }, + notExist: true, + }, + { + currents: []current{ + {num: 4, corrupt: true}, + {num: 3}, + {num: 6, current: true}, + {num: 5, backup: true}, + }, + corrupt: true, + }, + } + for i, tc := range cases { + t.Logf("Test-%d", i) + temp := tempDir(t) + fs, err := OpenFile(temp, false) + if err != nil { + t.Fatal("OpenFile: got error: ", err) + } + for _, cur := range tc.currents { + var curName string + switch { + case cur.current: + curName = "CURRENT" + case cur.backup: + curName = "CURRENT.bak" + default: + curName = fmt.Sprintf("CURRENT.%d", cur.num) + } + fd := FileDesc{Type: TypeManifest, Num: cur.num} + content := fmt.Sprintf("%s\n", fsGenName(fd)) + if cur.corrupt { + content = content[:len(content)-1-rand.Intn(3)] + } + if err := ioutil.WriteFile(filepath.Join(temp, curName), []byte(content), 0644); err != nil { + t.Fatal(err) + } + if cur.manifest { + w, err := fs.Create(fd) + if err != nil { + t.Fatal(err) + } + if _, err := w.Write([]byte("TEST")); err != nil { + t.Fatal(err) + } + w.Close() + } + } + ret, err := fs.GetMeta() + if tc.notExist { + if err != os.ErrNotExist { + t.Fatalf("expect ErrNotExist, got: %v", err) + } + } else if tc.corrupt { + if !isCorrupted(err) { + t.Fatalf("expect ErrCorrupted, got: %v", err) + } + } else { + if err != nil { + t.Fatal(err) + } + if ret.Type != TypeManifest { + t.Fatalf("expecting manifest, got: %s", ret.Type) + } + if ret.Num != tc.expect { + t.Fatalf("invalid num, expect=%d got=%d", tc.expect, ret.Num) + } + fis, err := ioutil.ReadDir(temp) + if err != nil { + t.Fatal(err) + } + for _, fi := range fis { + if strings.HasPrefix(fi.Name(), "CURRENT") { + switch fi.Name() { + case "CURRENT", "CURRENT.bak": + default: + t.Fatalf("found rouge CURRENT file: %s", fi.Name()) + } + } + t.Logf("-> %s", fi.Name()) + } + } + os.RemoveAll(temp) + } +} + +func TestFileStorage_ParseFileName(t *testing.T) { + for _, c := range cases { + for _, name := range append([]string{c.name}, c.oldName...) 
{ + fd, ok := fsParseName(name) + if !ok { + t.Errorf("cannot parse filename '%s'", name) + continue + } + if fd.Type != c.ftype { + t.Errorf("filename '%s' invalid type got '%d', want '%d'", name, fd.Type, c.ftype) + } + if fd.Num != c.num { + t.Errorf("filename '%s' invalid number got '%d', want '%d'", name, fd.Num, c.num) + } + } + } +} + +func TestFileStorage_InvalidFileName(t *testing.T) { + for _, name := range invalidCases { + if fsParseNamePtr(name, nil) { + t.Errorf("filename '%s' should be invalid", name) + } + } +} + +func TestFileStorage_Locking(t *testing.T) { + temp := tempDir(t) + defer os.RemoveAll(temp) + + p1, err := OpenFile(temp, false) + if err != nil { + t.Fatal("OpenFile(1): got error: ", err) + } + + p2, err := OpenFile(temp, false) + if err != nil { + t.Logf("OpenFile(2): got error: %s (expected)", err) + } else { + p2.Close() + p1.Close() + t.Fatal("OpenFile(2): expect error") + } + + p1.Close() + + p3, err := OpenFile(temp, false) + if err != nil { + t.Fatal("OpenFile(3): got error: ", err) + } + defer p3.Close() + + l, err := p3.Lock() + if err != nil { + t.Fatal("storage lock failed(1): ", err) + } + _, err = p3.Lock() + if err == nil { + t.Fatal("expect error for second storage lock attempt") + } else { + t.Logf("storage lock got error: %s (expected)", err) + } + l.Unlock() + _, err = p3.Lock() + if err != nil { + t.Fatal("storage lock failed(2): ", err) + } +} + +func TestFileStorage_ReadOnlyLocking(t *testing.T) { + temp := tempDir(t) + defer os.RemoveAll(temp) + + p1, err := OpenFile(temp, false) + if err != nil { + t.Fatal("OpenFile(1): got error: ", err) + } + + _, err = OpenFile(temp, true) + if err != nil { + t.Logf("OpenFile(2): got error: %s (expected)", err) + } else { + t.Fatal("OpenFile(2): expect error") + } + + p1.Close() + + p3, err := OpenFile(temp, true) + if err != nil { + t.Fatal("OpenFile(3): got error: ", err) + } + + p4, err := OpenFile(temp, true) + if err != nil { + t.Fatal("OpenFile(4): got error: ", err) + } + + _, err = OpenFile(temp, false) + if err != nil { + t.Logf("OpenFile(5): got error: %s (expected)", err) + } else { + t.Fatal("OpenFile(2): expect error") + } + + p3.Close() + p4.Close() +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_unix.go b/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_unix.go new file mode 100644 index 0000000..d75f66a --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_unix.go @@ -0,0 +1,98 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +// +build darwin dragonfly freebsd linux netbsd openbsd + +package storage + +import ( + "os" + "syscall" +) + +type unixFileLock struct { + f *os.File +} + +func (fl *unixFileLock) release() error { + if err := setFileLock(fl.f, false, false); err != nil { + return err + } + return fl.f.Close() +} + +func newFileLock(path string, readOnly bool) (fl fileLock, err error) { + var flag int + if readOnly { + flag = os.O_RDONLY + } else { + flag = os.O_RDWR + } + f, err := os.OpenFile(path, flag, 0) + if os.IsNotExist(err) { + f, err = os.OpenFile(path, flag|os.O_CREATE, 0644) + } + if err != nil { + return + } + err = setFileLock(f, readOnly, true) + if err != nil { + f.Close() + return + } + fl = &unixFileLock{f: f} + return +} + +func setFileLock(f *os.File, readOnly, lock bool) error { + how := syscall.LOCK_UN + if lock { + if readOnly { + how = syscall.LOCK_SH + } else { + how = syscall.LOCK_EX + } + } + return syscall.Flock(int(f.Fd()), how|syscall.LOCK_NB) +} + +func rename(oldpath, newpath string) error { + return os.Rename(oldpath, newpath) +} + +func isErrInvalid(err error) bool { + if err == os.ErrInvalid { + return true + } + // Go < 1.8 + if syserr, ok := err.(*os.SyscallError); ok && syserr.Err == syscall.EINVAL { + return true + } + // Go >= 1.8 returns *os.PathError instead + if patherr, ok := err.(*os.PathError); ok && patherr.Err == syscall.EINVAL { + return true + } + return false +} + +func syncDir(name string) error { + // As per fsync manpage, Linux seems to expect fsync on directory, however + // some system don't support this, so we will ignore syscall.EINVAL. + // + // From fsync(2): + // Calling fsync() does not necessarily ensure that the entry in the + // directory containing the file has also reached disk. For that an + // explicit fsync() on a file descriptor for the directory is also needed. + f, err := os.Open(name) + if err != nil { + return err + } + defer f.Close() + if err := f.Sync(); err != nil && !isErrInvalid(err) { + return err + } + return nil +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_windows.go b/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_windows.go new file mode 100644 index 0000000..899335f --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/storage/file_storage_windows.go @@ -0,0 +1,78 @@ +// Copyright (c) 2013, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
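Each platform file in this group (nacl, plan9, solaris, unix, and this windows one) supplies the same three primitives: newFileLock, rename and syncDir. A small self-contained sketch of the rename contract they implement, where an existing destination is replaced (POSIX renames atomically, plan9 removes the target first, Windows passes MOVEFILE_REPLACE_EXISTING); the paths and contents here are illustrative:

package main

import (
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"
)

func main() {
	dir, _ := ioutil.TempDir("", "rename-demo-")
	defer os.RemoveAll(dir)

	oldp := filepath.Join(dir, "CURRENT.2")
	newp := filepath.Join(dir, "CURRENT")
	ioutil.WriteFile(oldp, []byte("MANIFEST-000002\n"), 0644)
	ioutil.WriteFile(newp, []byte("MANIFEST-000001\n"), 0644)

	// On POSIX this is a single atomic rename; the other platforms emulate
	// the replace-existing behaviour as closely as the OS allows.
	if err := os.Rename(oldp, newp); err != nil {
		fmt.Println("rename:", err)
		return
	}
	b, _ := ioutil.ReadFile(newp)
	fmt.Print(string(b)) // MANIFEST-000002
}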
+ +package storage + +import ( + "syscall" + "unsafe" +) + +var ( + modkernel32 = syscall.NewLazyDLL("kernel32.dll") + + procMoveFileExW = modkernel32.NewProc("MoveFileExW") +) + +const ( + _MOVEFILE_REPLACE_EXISTING = 1 +) + +type windowsFileLock struct { + fd syscall.Handle +} + +func (fl *windowsFileLock) release() error { + return syscall.Close(fl.fd) +} + +func newFileLock(path string, readOnly bool) (fl fileLock, err error) { + pathp, err := syscall.UTF16PtrFromString(path) + if err != nil { + return + } + var access, shareMode uint32 + if readOnly { + access = syscall.GENERIC_READ + shareMode = syscall.FILE_SHARE_READ + } else { + access = syscall.GENERIC_READ | syscall.GENERIC_WRITE + } + fd, err := syscall.CreateFile(pathp, access, shareMode, nil, syscall.OPEN_EXISTING, syscall.FILE_ATTRIBUTE_NORMAL, 0) + if err == syscall.ERROR_FILE_NOT_FOUND { + fd, err = syscall.CreateFile(pathp, access, shareMode, nil, syscall.OPEN_ALWAYS, syscall.FILE_ATTRIBUTE_NORMAL, 0) + } + if err != nil { + return + } + fl = &windowsFileLock{fd: fd} + return +} + +func moveFileEx(from *uint16, to *uint16, flags uint32) error { + r1, _, e1 := syscall.Syscall(procMoveFileExW.Addr(), 3, uintptr(unsafe.Pointer(from)), uintptr(unsafe.Pointer(to)), uintptr(flags)) + if r1 == 0 { + if e1 != 0 { + return error(e1) + } + return syscall.EINVAL + } + return nil +} + +func rename(oldpath, newpath string) error { + from, err := syscall.UTF16PtrFromString(oldpath) + if err != nil { + return err + } + to, err := syscall.UTF16PtrFromString(newpath) + if err != nil { + return err + } + return moveFileEx(from, to, _MOVEFILE_REPLACE_EXISTING) +} + +func syncDir(name string) error { return nil } diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/storage/mem_storage.go b/vendor/github.com/syndtr/goleveldb/leveldb/storage/mem_storage.go new file mode 100644 index 0000000..838f1be --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/storage/mem_storage.go @@ -0,0 +1,222 @@ +// Copyright (c) 2013, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package storage + +import ( + "bytes" + "os" + "sync" +) + +const typeShift = 4 + +// Verify at compile-time that typeShift is large enough to cover all FileType +// values by confirming that 0 == 0. +var _ [0]struct{} = [TypeAll >> typeShift]struct{}{} + +type memStorageLock struct { + ms *memStorage +} + +func (lock *memStorageLock) Unlock() { + ms := lock.ms + ms.mu.Lock() + defer ms.mu.Unlock() + if ms.slock == lock { + ms.slock = nil + } + return +} + +// memStorage is a memory-backed storage. +type memStorage struct { + mu sync.Mutex + slock *memStorageLock + files map[uint64]*memFile + meta FileDesc +} + +// NewMemStorage returns a new memory-backed storage implementation. 
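+//
+// A godoc-style sketch of typical standalone use (the file number is
+// illustrative):
+//
+//	stor := storage.NewMemStorage()
+//	w, _ := stor.Create(storage.FileDesc{Type: storage.TypeJournal, Num: 1})
+//	w.Write([]byte("journal bytes"))
+//	w.Close()
+//	r, _ := stor.Open(storage.FileDesc{Type: storage.TypeJournal, Num: 1})
+//	defer r.Close()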
+func NewMemStorage() Storage {
+	return &memStorage{
+		files: make(map[uint64]*memFile),
+	}
+}
+
+func (ms *memStorage) Lock() (Locker, error) {
+	ms.mu.Lock()
+	defer ms.mu.Unlock()
+	if ms.slock != nil {
+		return nil, ErrLocked
+	}
+	ms.slock = &memStorageLock{ms: ms}
+	return ms.slock, nil
+}
+
+func (*memStorage) Log(str string) {}
+
+func (ms *memStorage) SetMeta(fd FileDesc) error {
+	if !FileDescOk(fd) {
+		return ErrInvalidFile
+	}
+
+	ms.mu.Lock()
+	ms.meta = fd
+	ms.mu.Unlock()
+	return nil
+}
+
+func (ms *memStorage) GetMeta() (FileDesc, error) {
+	ms.mu.Lock()
+	defer ms.mu.Unlock()
+	if ms.meta.Zero() {
+		return FileDesc{}, os.ErrNotExist
+	}
+	return ms.meta, nil
+}
+
+func (ms *memStorage) List(ft FileType) ([]FileDesc, error) {
+	ms.mu.Lock()
+	var fds []FileDesc
+	for x := range ms.files {
+		fd := unpackFile(x)
+		if fd.Type&ft != 0 {
+			fds = append(fds, fd)
+		}
+	}
+	ms.mu.Unlock()
+	return fds, nil
+}
+
+func (ms *memStorage) Open(fd FileDesc) (Reader, error) {
+	if !FileDescOk(fd) {
+		return nil, ErrInvalidFile
+	}
+
+	ms.mu.Lock()
+	defer ms.mu.Unlock()
+	if m, exist := ms.files[packFile(fd)]; exist {
+		if m.open {
+			return nil, errFileOpen
+		}
+		m.open = true
+		return &memReader{Reader: bytes.NewReader(m.Bytes()), ms: ms, m: m}, nil
+	}
+	return nil, os.ErrNotExist
+}
+
+func (ms *memStorage) Create(fd FileDesc) (Writer, error) {
+	if !FileDescOk(fd) {
+		return nil, ErrInvalidFile
+	}
+
+	x := packFile(fd)
+	ms.mu.Lock()
+	defer ms.mu.Unlock()
+	m, exist := ms.files[x]
+	if exist {
+		if m.open {
+			return nil, errFileOpen
+		}
+		m.Reset()
+	} else {
+		m = &memFile{}
+		ms.files[x] = m
+	}
+	m.open = true
+	return &memWriter{memFile: m, ms: ms}, nil
+}
+
+func (ms *memStorage) Remove(fd FileDesc) error {
+	if !FileDescOk(fd) {
+		return ErrInvalidFile
+	}
+
+	x := packFile(fd)
+	ms.mu.Lock()
+	defer ms.mu.Unlock()
+	if _, exist := ms.files[x]; exist {
+		delete(ms.files, x)
+		return nil
+	}
+	return os.ErrNotExist
+}
+
+func (ms *memStorage) Rename(oldfd, newfd FileDesc) error {
+	if !FileDescOk(oldfd) || !FileDescOk(newfd) {
+		return ErrInvalidFile
+	}
+	if oldfd == newfd {
+		return nil
+	}
+
+	oldx := packFile(oldfd)
+	newx := packFile(newfd)
+	ms.mu.Lock()
+	defer ms.mu.Unlock()
+	oldm, exist := ms.files[oldx]
+	if !exist {
+		return os.ErrNotExist
+	}
+	newm, exist := ms.files[newx]
+	if (exist && newm.open) || oldm.open {
+		return errFileOpen
+	}
+	delete(ms.files, oldx)
+	ms.files[newx] = oldm
+	return nil
+}
+
+func (*memStorage) Close() error { return nil }
+
+type memFile struct {
+	bytes.Buffer
+	open bool
+}
+
+type memReader struct {
+	*bytes.Reader
+	ms     *memStorage
+	m      *memFile
+	closed bool
+}
+
+func (mr *memReader) Close() error {
+	mr.ms.mu.Lock()
+	defer mr.ms.mu.Unlock()
+	if mr.closed {
+		return ErrClosed
+	}
+	mr.m.open = false
+	return nil
+}
+
+type memWriter struct {
+	*memFile
+	ms     *memStorage
+	closed bool
+}
+
+func (*memWriter) Sync() error { return nil }
+
+func (mw *memWriter) Close() error {
+	mw.ms.mu.Lock()
+	defer mw.ms.mu.Unlock()
+	if mw.closed {
+		return ErrClosed
+	}
+	mw.memFile.open = false
+	return nil
+}
+
+func packFile(fd FileDesc) uint64 {
+	return uint64(fd.Num)<<typeShift | uint64(fd.Type)
+}
+
+func unpackFile(x uint64) FileDesc {
+	return FileDesc{FileType(x) & TypeAll, int64(x >> typeShift)}
+}
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/storage/mem_storage_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/storage/mem_storage_test.go
new file mode 100644
index 0000000..bb0a19d
--- /dev/null
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/storage/mem_storage_test.go
@@ -0,0 +1,117 @@
+// Copyright (c) 2013, Suryandaru Triandana <syndtr@gmail.com>
+// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package storage + +import ( + "bytes" + "fmt" + "testing" +) + +func TestMemStorage(t *testing.T) { + m := NewMemStorage() + + l, err := m.Lock() + if err != nil { + t.Fatal("storage lock failed(1): ", err) + } + _, err = m.Lock() + if err == nil { + t.Fatal("expect error for second storage lock attempt") + } else { + t.Logf("storage lock got error: %s (expected)", err) + } + l.Unlock() + _, err = m.Lock() + if err != nil { + t.Fatal("storage lock failed(2): ", err) + } + + w, err := m.Create(FileDesc{TypeTable, 1}) + if err != nil { + t.Fatal("Storage.Create: ", err) + } + w.Write([]byte("abc")) + w.Close() + if fds, _ := m.List(TypeAll); len(fds) != 1 { + t.Fatal("invalid GetFiles len") + } + buf := new(bytes.Buffer) + r, err := m.Open(FileDesc{TypeTable, 1}) + if err != nil { + t.Fatal("Open: got error: ", err) + } + buf.ReadFrom(r) + r.Close() + if got := buf.String(); got != "abc" { + t.Fatalf("Read: invalid value, want=abc got=%s", got) + } + if _, err := m.Open(FileDesc{TypeTable, 1}); err != nil { + t.Fatal("Open: got error: ", err) + } + if _, err := m.Open(FileDesc{TypeTable, 1}); err == nil { + t.Fatal("expecting error") + } + m.Remove(FileDesc{TypeTable, 1}) + if fds, _ := m.List(TypeAll); len(fds) != 0 { + t.Fatal("invalid GetFiles len", len(fds)) + } + if _, err := m.Open(FileDesc{TypeTable, 1}); err == nil { + t.Fatal("expecting error") + } +} + +func TestMemStorageRename(t *testing.T) { + fd1 := FileDesc{Type: TypeTable, Num: 1} + fd2 := FileDesc{Type: TypeTable, Num: 2} + + m := NewMemStorage() + w, err := m.Create(fd1) + if err != nil { + t.Fatalf("Storage.Create: %v", err) + } + + fmt.Fprintf(w, "abc") + w.Close() + + rd, err := m.Open(fd1) + if err != nil { + t.Fatalf("Storage.Open(%v): %v", fd1, err) + } + rd.Close() + + fds, err := m.List(TypeAll) + if err != nil { + t.Fatalf("Storage.List: %v", err) + } + for _, fd := range fds { + if !FileDescOk(fd) { + t.Errorf("Storage.List -> FileDescOk(%q)", fd) + } + } + + err = m.Rename(fd1, fd2) + if err != nil { + t.Fatalf("Storage.Rename: %v", err) + } + + rd, err = m.Open(fd2) + if err != nil { + t.Fatalf("Storage.Open(%v): %v", fd2, err) + } + rd.Close() + + fds, err = m.List(TypeAll) + if err != nil { + t.Fatalf("Storage.List: %v", err) + } + for _, fd := range fds { + if !FileDescOk(fd) { + t.Errorf("Storage.List -> FileDescOk(%q)", fd) + } + } +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/storage/storage.go b/vendor/github.com/syndtr/goleveldb/leveldb/storage/storage.go new file mode 100644 index 0000000..4e4a724 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/storage/storage.go @@ -0,0 +1,187 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Package storage provides storage abstraction for LevelDB. +package storage + +import ( + "errors" + "fmt" + "io" +) + +// FileType represent a file type. +type FileType int + +// File types. 
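+// These are bit flags that can be OR'ed together when calling Storage.List,
+// e.g. List(TypeManifest | TypeTable).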
+const (
+	TypeManifest FileType = 1 << iota
+	TypeJournal
+	TypeTable
+	TypeTemp
+
+	TypeAll = TypeManifest | TypeJournal | TypeTable | TypeTemp
+)
+
+func (t FileType) String() string {
+	switch t {
+	case TypeManifest:
+		return "manifest"
+	case TypeJournal:
+		return "journal"
+	case TypeTable:
+		return "table"
+	case TypeTemp:
+		return "temp"
+	}
+	return fmt.Sprintf("<unknown:%d>", t)
+}
+
+// Common errors.
+var (
+	ErrInvalidFile = errors.New("leveldb/storage: invalid file for argument")
+	ErrLocked      = errors.New("leveldb/storage: already locked")
+	ErrClosed      = errors.New("leveldb/storage: closed")
+)
+
+// ErrCorrupted is the type that wraps errors that indicate corruption of
+// a file. Package storage has its own type instead of using
+// errors.ErrCorrupted to prevent a circular import.
+type ErrCorrupted struct {
+	Fd  FileDesc
+	Err error
+}
+
+func isCorrupted(err error) bool {
+	switch err.(type) {
+	case *ErrCorrupted:
+		return true
+	}
+	return false
+}
+
+func (e *ErrCorrupted) Error() string {
+	if !e.Fd.Zero() {
+		return fmt.Sprintf("%v [file=%v]", e.Err, e.Fd)
+	}
+	return e.Err.Error()
+}
+
+// Syncer is the interface that wraps the basic Sync method.
+type Syncer interface {
+	// Sync commits the current contents of the file to stable storage.
+	Sync() error
+}
+
+// Reader is the interface that groups the basic Read, Seek, ReadAt and Close
+// methods.
+type Reader interface {
+	io.ReadSeeker
+	io.ReaderAt
+	io.Closer
+}
+
+// Writer is the interface that groups the basic Write, Sync and Close
+// methods.
+type Writer interface {
+	io.WriteCloser
+	Syncer
+}
+
+// Locker is the interface that wraps the Unlock method.
+type Locker interface {
+	Unlock()
+}
+
+// FileDesc is a 'file descriptor'.
+type FileDesc struct {
+	Type FileType
+	Num  int64
+}
+
+func (fd FileDesc) String() string {
+	switch fd.Type {
+	case TypeManifest:
+		return fmt.Sprintf("MANIFEST-%06d", fd.Num)
+	case TypeJournal:
+		return fmt.Sprintf("%06d.log", fd.Num)
+	case TypeTable:
+		return fmt.Sprintf("%06d.ldb", fd.Num)
+	case TypeTemp:
+		return fmt.Sprintf("%06d.tmp", fd.Num)
+	default:
+		return fmt.Sprintf("%#x-%d", fd.Type, fd.Num)
+	}
+}
+
+// Zero returns true if fd == (FileDesc{}).
+func (fd FileDesc) Zero() bool {
+	return fd == (FileDesc{})
+}
+
+// FileDescOk returns true if fd is a valid 'file descriptor'.
+func FileDescOk(fd FileDesc) bool {
+	switch fd.Type {
+	case TypeManifest:
+	case TypeJournal:
+	case TypeTable:
+	case TypeTemp:
+	default:
+		return false
+	}
+	return fd.Num >= 0
+}
+
+// Storage is the storage abstraction. A storage instance must be safe for
+// concurrent use.
+type Storage interface {
+	// Lock locks the storage. Any subsequent attempt to call Lock will fail
+	// until the last lock is released.
+	// The caller should call Unlock after use.
+	Lock() (Locker, error)
+
+	// Log logs a string. An implementation may write to a file, to stdout,
+	// or simply do nothing.
+	Log(str string)
+
+	// SetMeta stores the 'file descriptor' so that it can later be retrieved
+	// using the GetMeta method. The 'file descriptor' should point to a valid
+	// file, and the change should happen atomically.
+	SetMeta(fd FileDesc) error
+
+	// GetMeta returns the 'file descriptor' stored in meta. The 'file
+	// descriptor' can be updated using the SetMeta method.
+	// Returns os.ErrNotExist if meta doesn't store any 'file descriptor', or
+	// if the 'file descriptor' points to a nonexistent file.
+ GetMeta() (FileDesc, error) + + // List returns file descriptors that match the given file types. + // The file types may be OR'ed together. + List(ft FileType) ([]FileDesc, error) + + // Open opens file with the given 'file descriptor' read-only. + // Returns os.ErrNotExist error if the file does not exist. + // Returns ErrClosed if the underlying storage is closed. + Open(fd FileDesc) (Reader, error) + + // Create creates file with the given 'file descriptor', truncate if already + // exist and opens write-only. + // Returns ErrClosed if the underlying storage is closed. + Create(fd FileDesc) (Writer, error) + + // Remove removes file with the given 'file descriptor'. + // Returns ErrClosed if the underlying storage is closed. + Remove(fd FileDesc) error + + // Rename renames file from oldfd to newfd. + // Returns ErrClosed if the underlying storage is closed. + Rename(oldfd, newfd FileDesc) error + + // Close closes the storage. + // It is valid to call Close multiple times. Other methods should not be + // called after the storage has been closed. + Close() error +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/table.go b/vendor/github.com/syndtr/goleveldb/leveldb/table.go new file mode 100644 index 0000000..b7759b2 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/table.go @@ -0,0 +1,600 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "bytes" + "fmt" + "sort" + "sync/atomic" + + "github.com/syndtr/goleveldb/leveldb/cache" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" + "github.com/syndtr/goleveldb/leveldb/table" + "github.com/syndtr/goleveldb/leveldb/util" +) + +// tFile holds basic information about a table. +type tFile struct { + fd storage.FileDesc + seekLeft int32 + size int64 + imin, imax internalKey +} + +// Returns true if given key is after largest key of this table. +func (t *tFile) after(icmp *iComparer, ukey []byte) bool { + return ukey != nil && icmp.uCompare(ukey, t.imax.ukey()) > 0 +} + +// Returns true if given key is before smallest key of this table. +func (t *tFile) before(icmp *iComparer, ukey []byte) bool { + return ukey != nil && icmp.uCompare(ukey, t.imin.ukey()) < 0 +} + +// Returns true if given key range overlaps with this table key range. +func (t *tFile) overlaps(icmp *iComparer, umin, umax []byte) bool { + return !t.after(icmp, umin) && !t.before(icmp, umax) +} + +// Cosumes one seek and return current seeks left. +func (t *tFile) consumeSeek() int32 { + return atomic.AddInt32(&t.seekLeft, -1) +} + +// Creates new tFile. +func newTableFile(fd storage.FileDesc, size int64, imin, imax internalKey) *tFile { + f := &tFile{ + fd: fd, + size: size, + imin: imin, + imax: imax, + } + + // We arrange to automatically compact this file after + // a certain number of seeks. Let's assume: + // (1) One seek costs 10ms + // (2) Writing or reading 1MB costs 10ms (100MB/s) + // (3) A compaction of 1MB does 25MB of IO: + // 1MB read from this level + // 10-12MB read from next level (boundaries may be misaligned) + // 10-12MB written to next level + // This implies that 25 seeks cost the same as the compaction + // of 1MB of data. I.e., one seek costs approximately the + // same as the compaction of 40KB of data. 
We are a little + // conservative and allow approximately one seek for every 16KB + // of data before triggering a compaction. + f.seekLeft = int32(size / 16384) + if f.seekLeft < 100 { + f.seekLeft = 100 + } + + return f +} + +func tableFileFromRecord(r atRecord) *tFile { + return newTableFile(storage.FileDesc{Type: storage.TypeTable, Num: r.num}, r.size, r.imin, r.imax) +} + +// tFiles hold multiple tFile. +type tFiles []*tFile + +func (tf tFiles) Len() int { return len(tf) } +func (tf tFiles) Swap(i, j int) { tf[i], tf[j] = tf[j], tf[i] } + +func (tf tFiles) nums() string { + x := "[ " + for i, f := range tf { + if i != 0 { + x += ", " + } + x += fmt.Sprint(f.fd.Num) + } + x += " ]" + return x +} + +// Returns true if i smallest key is less than j. +// This used for sort by key in ascending order. +func (tf tFiles) lessByKey(icmp *iComparer, i, j int) bool { + a, b := tf[i], tf[j] + n := icmp.Compare(a.imin, b.imin) + if n == 0 { + return a.fd.Num < b.fd.Num + } + return n < 0 +} + +// Returns true if i file number is greater than j. +// This used for sort by file number in descending order. +func (tf tFiles) lessByNum(i, j int) bool { + return tf[i].fd.Num > tf[j].fd.Num +} + +// Sorts tables by key in ascending order. +func (tf tFiles) sortByKey(icmp *iComparer) { + sort.Sort(&tFilesSortByKey{tFiles: tf, icmp: icmp}) +} + +// Sorts tables by file number in descending order. +func (tf tFiles) sortByNum() { + sort.Sort(&tFilesSortByNum{tFiles: tf}) +} + +// Returns sum of all tables size. +func (tf tFiles) size() (sum int64) { + for _, t := range tf { + sum += t.size + } + return sum +} + +// Searches smallest index of tables whose its smallest +// key is after or equal with given key. +func (tf tFiles) searchMin(icmp *iComparer, ikey internalKey) int { + return sort.Search(len(tf), func(i int) bool { + return icmp.Compare(tf[i].imin, ikey) >= 0 + }) +} + +// Searches smallest index of tables whose its largest +// key is after or equal with given key. +func (tf tFiles) searchMax(icmp *iComparer, ikey internalKey) int { + return sort.Search(len(tf), func(i int) bool { + return icmp.Compare(tf[i].imax, ikey) >= 0 + }) +} + +// Searches smallest index of tables whose its file number +// is smaller than the given number. +func (tf tFiles) searchNumLess(num int64) int { + return sort.Search(len(tf), func(i int) bool { + return tf[i].fd.Num < num + }) +} + +// Searches smallest index of tables whose its smallest +// key is after the given key. +func (tf tFiles) searchMinUkey(icmp *iComparer, umin []byte) int { + return sort.Search(len(tf), func(i int) bool { + return icmp.ucmp.Compare(tf[i].imin.ukey(), umin) > 0 + }) +} + +// Searches smallest index of tables whose its largest +// key is after the given key. +func (tf tFiles) searchMaxUkey(icmp *iComparer, umax []byte) int { + return sort.Search(len(tf), func(i int) bool { + return icmp.ucmp.Compare(tf[i].imax.ukey(), umax) > 0 + }) +} + +// Returns true if given key range overlaps with one or more +// tables key range. If unsorted is true then binary search will not be used. +func (tf tFiles) overlaps(icmp *iComparer, umin, umax []byte, unsorted bool) bool { + if unsorted { + // Check against all files. + for _, t := range tf { + if t.overlaps(icmp, umin, umax) { + return true + } + } + return false + } + + i := 0 + if len(umin) > 0 { + // Find the earliest possible internal key for min. 
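+		// (keyMaxSeq with keyTypeSeek is the smallest possible internal key
+		// for the user key umin, so searchMax returns the first table whose
+		// largest key is >= umin.)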
+ i = tf.searchMax(icmp, makeInternalKey(nil, umin, keyMaxSeq, keyTypeSeek)) + } + if i >= len(tf) { + // Beginning of range is after all files, so no overlap. + return false + } + return !tf[i].before(icmp, umax) +} + +// Returns tables whose its key range overlaps with given key range. +// Range will be expanded if ukey found hop across tables. +// If overlapped is true then the search will be restarted if umax +// expanded. +// The dst content will be overwritten. +func (tf tFiles) getOverlaps(dst tFiles, icmp *iComparer, umin, umax []byte, overlapped bool) tFiles { + // Short circuit if tf is empty + if len(tf) == 0 { + return nil + } + // For non-zero levels, there is no ukey hop across at all. + // And what's more, the files in these levels are strictly sorted, + // so use binary search instead of heavy traverse. + if !overlapped { + var begin, end int + // Determine the begin index of the overlapped file + if umin != nil { + index := tf.searchMinUkey(icmp, umin) + if index == 0 { + begin = 0 + } else if bytes.Compare(tf[index-1].imax.ukey(), umin) >= 0 { + // The min ukey overlaps with the index-1 file, expand it. + begin = index - 1 + } else { + begin = index + } + } + // Determine the end index of the overlapped file + if umax != nil { + index := tf.searchMaxUkey(icmp, umax) + if index == len(tf) { + end = len(tf) + } else if bytes.Compare(tf[index].imin.ukey(), umax) <= 0 { + // The max ukey overlaps with the index file, expand it. + end = index + 1 + } else { + end = index + } + } else { + end = len(tf) + } + // Ensure the overlapped file indexes are valid. + if begin >= end { + return nil + } + dst = make([]*tFile, end-begin) + copy(dst, tf[begin:end]) + return dst + } + + dst = dst[:0] + for i := 0; i < len(tf); { + t := tf[i] + if t.overlaps(icmp, umin, umax) { + if umin != nil && icmp.uCompare(t.imin.ukey(), umin) < 0 { + umin = t.imin.ukey() + dst = dst[:0] + i = 0 + continue + } else if umax != nil && icmp.uCompare(t.imax.ukey(), umax) > 0 { + umax = t.imax.ukey() + // Restart search if it is overlapped. + dst = dst[:0] + i = 0 + continue + } + + dst = append(dst, t) + } + i++ + } + + return dst +} + +// Returns tables key range. +func (tf tFiles) getRange(icmp *iComparer) (imin, imax internalKey) { + for i, t := range tf { + if i == 0 { + imin, imax = t.imin, t.imax + continue + } + if icmp.Compare(t.imin, imin) < 0 { + imin = t.imin + } + if icmp.Compare(t.imax, imax) > 0 { + imax = t.imax + } + } + + return +} + +// Creates iterator index from tables. +func (tf tFiles) newIndexIterator(tops *tOps, icmp *iComparer, slice *util.Range, ro *opt.ReadOptions) iterator.IteratorIndexer { + if slice != nil { + var start, limit int + if slice.Start != nil { + start = tf.searchMax(icmp, internalKey(slice.Start)) + } + if slice.Limit != nil { + limit = tf.searchMin(icmp, internalKey(slice.Limit)) + } else { + limit = tf.Len() + } + tf = tf[start:limit] + } + return iterator.NewArrayIndexer(&tFilesArrayIndexer{ + tFiles: tf, + tops: tops, + icmp: icmp, + slice: slice, + ro: ro, + }) +} + +// Tables iterator index. +type tFilesArrayIndexer struct { + tFiles + tops *tOps + icmp *iComparer + slice *util.Range + ro *opt.ReadOptions +} + +func (a *tFilesArrayIndexer) Search(key []byte) int { + return a.searchMax(a.icmp, internalKey(key)) +} + +func (a *tFilesArrayIndexer) Get(i int) iterator.Iterator { + if i == 0 || i == a.Len()-1 { + return a.tops.newIterator(a.tFiles[i], a.slice, a.ro) + } + return a.tops.newIterator(a.tFiles[i], nil, a.ro) +} + +// Helper type for sortByKey. 
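+// It embeds tFiles for Len and Swap and supplies a key-ordered Less, so it
+// satisfies sort.Interface.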
+type tFilesSortByKey struct { + tFiles + icmp *iComparer +} + +func (x *tFilesSortByKey) Less(i, j int) bool { + return x.lessByKey(x.icmp, i, j) +} + +// Helper type for sortByNum. +type tFilesSortByNum struct { + tFiles +} + +func (x *tFilesSortByNum) Less(i, j int) bool { + return x.lessByNum(i, j) +} + +// Table operations. +type tOps struct { + s *session + noSync bool + evictRemoved bool + cache *cache.Cache + bcache *cache.Cache + bpool *util.BufferPool +} + +// Creates an empty table and returns table writer. +func (t *tOps) create() (*tWriter, error) { + fd := storage.FileDesc{Type: storage.TypeTable, Num: t.s.allocFileNum()} + fw, err := t.s.stor.Create(fd) + if err != nil { + return nil, err + } + return &tWriter{ + t: t, + fd: fd, + w: fw, + tw: table.NewWriter(fw, t.s.o.Options), + }, nil +} + +// Builds table from src iterator. +func (t *tOps) createFrom(src iterator.Iterator) (f *tFile, n int, err error) { + w, err := t.create() + if err != nil { + return + } + + defer func() { + if err != nil { + w.drop() + } + }() + + for src.Next() { + err = w.append(src.Key(), src.Value()) + if err != nil { + return + } + } + err = src.Error() + if err != nil { + return + } + + n = w.tw.EntriesLen() + f, err = w.finish() + return +} + +// Opens table. It returns a cache handle, which should +// be released after use. +func (t *tOps) open(f *tFile) (ch *cache.Handle, err error) { + ch = t.cache.Get(0, uint64(f.fd.Num), func() (size int, value cache.Value) { + var r storage.Reader + r, err = t.s.stor.Open(f.fd) + if err != nil { + return 0, nil + } + + var bcache *cache.NamespaceGetter + if t.bcache != nil { + bcache = &cache.NamespaceGetter{Cache: t.bcache, NS: uint64(f.fd.Num)} + } + + var tr *table.Reader + tr, err = table.NewReader(r, f.size, f.fd, bcache, t.bpool, t.s.o.Options) + if err != nil { + r.Close() + return 0, nil + } + return 1, tr + + }) + if ch == nil && err == nil { + err = ErrClosed + } + return +} + +// Finds key/value pair whose key is greater than or equal to the +// given key. +func (t *tOps) find(f *tFile, key []byte, ro *opt.ReadOptions) (rkey, rvalue []byte, err error) { + ch, err := t.open(f) + if err != nil { + return nil, nil, err + } + defer ch.Release() + return ch.Value().(*table.Reader).Find(key, true, ro) +} + +// Finds key that is greater than or equal to the given key. +func (t *tOps) findKey(f *tFile, key []byte, ro *opt.ReadOptions) (rkey []byte, err error) { + ch, err := t.open(f) + if err != nil { + return nil, err + } + defer ch.Release() + return ch.Value().(*table.Reader).FindKey(key, true, ro) +} + +// Returns approximate offset of the given key. +func (t *tOps) offsetOf(f *tFile, key []byte) (offset int64, err error) { + ch, err := t.open(f) + if err != nil { + return + } + defer ch.Release() + return ch.Value().(*table.Reader).OffsetOf(key) +} + +// Creates an iterator from the given table. +func (t *tOps) newIterator(f *tFile, slice *util.Range, ro *opt.ReadOptions) iterator.Iterator { + ch, err := t.open(f) + if err != nil { + return iterator.NewEmptyIterator(err) + } + iter := ch.Value().(*table.Reader).NewIterator(slice, ro) + iter.SetReleaser(ch) + return iter +} + +// Removes table from persistent storage. It waits until +// no one use the the table. 
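+// The file is only deleted once the table's cache entry is released, i.e.
+// when no handle to the table remains; the delete callback passed to
+// cache.Delete below performs the actual removal.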
+func (t *tOps) remove(fd storage.FileDesc) { + t.cache.Delete(0, uint64(fd.Num), func() { + if err := t.s.stor.Remove(fd); err != nil { + t.s.logf("table@remove removing @%d %q", fd.Num, err) + } else { + t.s.logf("table@remove removed @%d", fd.Num) + } + if t.evictRemoved && t.bcache != nil { + t.bcache.EvictNS(uint64(fd.Num)) + } + // Try to reuse file num, useful for discarded transaction. + t.s.reuseFileNum(fd.Num) + }) +} + +// Closes the table ops instance. It will close all tables, +// regadless still used or not. +func (t *tOps) close() { + t.bpool.Close() + t.cache.Close() + if t.bcache != nil { + t.bcache.CloseWeak() + } +} + +// Creates new initialized table ops instance. +func newTableOps(s *session) *tOps { + var ( + cacher cache.Cacher + bcache *cache.Cache + bpool *util.BufferPool + ) + if s.o.GetOpenFilesCacheCapacity() > 0 { + cacher = cache.NewLRU(s.o.GetOpenFilesCacheCapacity()) + } + if !s.o.GetDisableBlockCache() { + var bcacher cache.Cacher + if s.o.GetBlockCacheCapacity() > 0 { + bcacher = s.o.GetBlockCacher().New(s.o.GetBlockCacheCapacity()) + } + bcache = cache.NewCache(bcacher) + } + if !s.o.GetDisableBufferPool() { + bpool = util.NewBufferPool(s.o.GetBlockSize() + 5) + } + return &tOps{ + s: s, + noSync: s.o.GetNoSync(), + evictRemoved: s.o.GetBlockCacheEvictRemoved(), + cache: cache.NewCache(cacher), + bcache: bcache, + bpool: bpool, + } +} + +// tWriter wraps the table writer. It keep track of file descriptor +// and added key range. +type tWriter struct { + t *tOps + + fd storage.FileDesc + w storage.Writer + tw *table.Writer + + first, last []byte +} + +// Append key/value pair to the table. +func (w *tWriter) append(key, value []byte) error { + if w.first == nil { + w.first = append([]byte{}, key...) + } + w.last = append(w.last[:0], key...) + return w.tw.Append(key, value) +} + +// Returns true if the table is empty. +func (w *tWriter) empty() bool { + return w.first == nil +} + +// Closes the storage.Writer. +func (w *tWriter) close() { + if w.w != nil { + w.w.Close() + w.w = nil + } +} + +// Finalizes the table and returns table file. +func (w *tWriter) finish() (f *tFile, err error) { + defer w.close() + err = w.tw.Close() + if err != nil { + return + } + if !w.t.noSync { + err = w.w.Sync() + if err != nil { + return + } + } + f = newTableFile(w.fd, int64(w.tw.BytesLen()), internalKey(w.first), internalKey(w.last)) + return +} + +// Drops the table. +func (w *tWriter) drop() { + w.close() + w.t.s.stor.Remove(w.fd) + w.t.s.reuseFileNum(w.fd.Num) + w.tw = nil + w.first = nil + w.last = nil +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/table/block_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/table/block_test.go new file mode 100644 index 0000000..00e6f9e --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/table/block_test.go @@ -0,0 +1,139 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package table + +import ( + "encoding/binary" + "fmt" + + . "github.com/onsi/ginkgo" + . 
"github.com/onsi/gomega" + + "github.com/syndtr/goleveldb/leveldb/comparer" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/testutil" + "github.com/syndtr/goleveldb/leveldb/util" +) + +type blockTesting struct { + tr *Reader + b *block +} + +func (t *blockTesting) TestNewIterator(slice *util.Range) iterator.Iterator { + return t.tr.newBlockIter(t.b, nil, slice, false) +} + +var _ = testutil.Defer(func() { + Describe("Block", func() { + Build := func(kv *testutil.KeyValue, restartInterval int) *blockTesting { + // Building the block. + bw := &blockWriter{ + restartInterval: restartInterval, + scratch: make([]byte, 30), + } + kv.Iterate(func(i int, key, value []byte) { + bw.append(key, value) + }) + bw.finish() + + // Opening the block. + data := bw.buf.Bytes() + restartsLen := int(binary.LittleEndian.Uint32(data[len(data)-4:])) + return &blockTesting{ + tr: &Reader{cmp: comparer.DefaultComparer}, + b: &block{ + data: data, + restartsLen: restartsLen, + restartsOffset: len(data) - (restartsLen+1)*4, + }, + } + } + + Describe("read test", func() { + for restartInterval := 1; restartInterval <= 5; restartInterval++ { + Describe(fmt.Sprintf("with restart interval of %d", restartInterval), func() { + kv := &testutil.KeyValue{} + Text := func() string { + return fmt.Sprintf("and %d keys", kv.Len()) + } + + Test := func() { + // Make block. + br := Build(kv, restartInterval) + // Do testing. + testutil.KeyValueTesting(nil, kv.Clone(), br, nil, nil) + } + + Describe(Text(), Test) + + kv.PutString("", "empty") + Describe(Text(), Test) + + kv.PutString("a1", "foo") + Describe(Text(), Test) + + kv.PutString("a2", "v") + Describe(Text(), Test) + + kv.PutString("a3qqwrkks", "hello") + Describe(Text(), Test) + + kv.PutString("a4", "bar") + Describe(Text(), Test) + + kv.PutString("a5111111", "v5") + kv.PutString("a6", "") + kv.PutString("a7", "v7") + kv.PutString("a8", "vvvvvvvvvvvvvvvvvvvvvv8") + kv.PutString("b", "v9") + kv.PutString("c9", "v9") + kv.PutString("c91", "v9") + kv.PutString("d0", "v9") + Describe(Text(), Test) + }) + } + }) + + Describe("out-of-bound slice test", func() { + kv := &testutil.KeyValue{} + kv.PutString("k1", "v1") + kv.PutString("k2", "v2") + kv.PutString("k3abcdefgg", "v3") + kv.PutString("k4", "v4") + kv.PutString("k5", "v5") + for restartInterval := 1; restartInterval <= 5; restartInterval++ { + Describe(fmt.Sprintf("with restart interval of %d", restartInterval), func() { + // Make block. + bt := Build(kv, restartInterval) + + Test := func(r *util.Range) func(done Done) { + return func(done Done) { + iter := bt.TestNewIterator(r) + Expect(iter.Error()).ShouldNot(HaveOccurred()) + + t := testutil.IteratorTesting{ + KeyValue: kv.Clone(), + Iter: iter, + } + + testutil.DoIteratorTesting(&t) + iter.Release() + done <- true + } + } + + It("Should do iterations and seeks correctly #0", + Test(&util.Range{Start: []byte("k0"), Limit: []byte("k6")}), 2.0) + + It("Should do iterations and seeks correctly #1", + Test(&util.Range{Start: []byte(""), Limit: []byte("zzzzzzz")}), 2.0) + }) + } + }) + }) +}) diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/table/reader.go b/vendor/github.com/syndtr/goleveldb/leveldb/table/reader.go new file mode 100644 index 0000000..496feb6 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/table/reader.go @@ -0,0 +1,1139 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. 
+// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package table + +import ( + "encoding/binary" + "fmt" + "io" + "sort" + "strings" + "sync" + + "github.com/golang/snappy" + + "github.com/syndtr/goleveldb/leveldb/cache" + "github.com/syndtr/goleveldb/leveldb/comparer" + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/filter" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" + "github.com/syndtr/goleveldb/leveldb/util" +) + +// Reader errors. +var ( + ErrNotFound = errors.ErrNotFound + ErrReaderReleased = errors.New("leveldb/table: reader released") + ErrIterReleased = errors.New("leveldb/table: iterator released") +) + +// ErrCorrupted describes error due to corruption. This error will be wrapped +// with errors.ErrCorrupted. +type ErrCorrupted struct { + Pos int64 + Size int64 + Kind string + Reason string +} + +func (e *ErrCorrupted) Error() string { + return fmt.Sprintf("leveldb/table: corruption on %s (pos=%d): %s", e.Kind, e.Pos, e.Reason) +} + +func max(x, y int) int { + if x > y { + return x + } + return y +} + +type block struct { + bpool *util.BufferPool + bh blockHandle + data []byte + restartsLen int + restartsOffset int +} + +func (b *block) seek(cmp comparer.Comparer, rstart, rlimit int, key []byte) (index, offset int, err error) { + index = sort.Search(b.restartsLen-rstart-(b.restartsLen-rlimit), func(i int) bool { + offset := int(binary.LittleEndian.Uint32(b.data[b.restartsOffset+4*(rstart+i):])) + offset++ // shared always zero, since this is a restart point + v1, n1 := binary.Uvarint(b.data[offset:]) // key length + _, n2 := binary.Uvarint(b.data[offset+n1:]) // value length + m := offset + n1 + n2 + return cmp.Compare(b.data[m:m+int(v1)], key) > 0 + }) + rstart - 1 + if index < rstart { + // The smallest key is greater-than key sought. 
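+		// Fall back to the first restart point within the search range.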
+ index = rstart + } + offset = int(binary.LittleEndian.Uint32(b.data[b.restartsOffset+4*index:])) + return +} + +func (b *block) restartIndex(rstart, rlimit, offset int) int { + return sort.Search(b.restartsLen-rstart-(b.restartsLen-rlimit), func(i int) bool { + return int(binary.LittleEndian.Uint32(b.data[b.restartsOffset+4*(rstart+i):])) > offset + }) + rstart - 1 +} + +func (b *block) restartOffset(index int) int { + return int(binary.LittleEndian.Uint32(b.data[b.restartsOffset+4*index:])) +} + +func (b *block) entry(offset int) (key, value []byte, nShared, n int, err error) { + if offset >= b.restartsOffset { + if offset != b.restartsOffset { + err = &ErrCorrupted{Reason: "entries offset not aligned"} + } + return + } + v0, n0 := binary.Uvarint(b.data[offset:]) // Shared prefix length + v1, n1 := binary.Uvarint(b.data[offset+n0:]) // Key length + v2, n2 := binary.Uvarint(b.data[offset+n0+n1:]) // Value length + m := n0 + n1 + n2 + n = m + int(v1) + int(v2) + if n0 <= 0 || n1 <= 0 || n2 <= 0 || offset+n > b.restartsOffset { + err = &ErrCorrupted{Reason: "entries corrupted"} + return + } + key = b.data[offset+m : offset+m+int(v1)] + value = b.data[offset+m+int(v1) : offset+n] + nShared = int(v0) + return +} + +func (b *block) Release() { + b.bpool.Put(b.data) + b.bpool = nil + b.data = nil +} + +type dir int + +const ( + dirReleased dir = iota - 1 + dirSOI + dirEOI + dirBackward + dirForward +) + +type blockIter struct { + tr *Reader + block *block + blockReleaser util.Releaser + releaser util.Releaser + key, value []byte + offset int + // Previous offset, only filled by Next. + prevOffset int + prevNode []int + prevKeys []byte + restartIndex int + // Iterator direction. + dir dir + // Restart index slice range. + riStart int + riLimit int + // Offset slice range. + offsetStart int + offsetRealStart int + offsetLimit int + // Error. 
+ err error +} + +func (i *blockIter) sErr(err error) { + i.err = err + i.key = nil + i.value = nil + i.prevNode = nil + i.prevKeys = nil +} + +func (i *blockIter) reset() { + if i.dir == dirBackward { + i.prevNode = i.prevNode[:0] + i.prevKeys = i.prevKeys[:0] + } + i.restartIndex = i.riStart + i.offset = i.offsetStart + i.dir = dirSOI + i.key = i.key[:0] + i.value = nil +} + +func (i *blockIter) isFirst() bool { + switch i.dir { + case dirForward: + return i.prevOffset == i.offsetRealStart + case dirBackward: + return len(i.prevNode) == 1 && i.restartIndex == i.riStart + } + return false +} + +func (i *blockIter) isLast() bool { + switch i.dir { + case dirForward, dirBackward: + return i.offset == i.offsetLimit + } + return false +} + +func (i *blockIter) First() bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + if i.dir == dirBackward { + i.prevNode = i.prevNode[:0] + i.prevKeys = i.prevKeys[:0] + } + i.dir = dirSOI + return i.Next() +} + +func (i *blockIter) Last() bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + if i.dir == dirBackward { + i.prevNode = i.prevNode[:0] + i.prevKeys = i.prevKeys[:0] + } + i.dir = dirEOI + return i.Prev() +} + +func (i *blockIter) Seek(key []byte) bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + ri, offset, err := i.block.seek(i.tr.cmp, i.riStart, i.riLimit, key) + if err != nil { + i.sErr(err) + return false + } + i.restartIndex = ri + i.offset = max(i.offsetStart, offset) + if i.dir == dirSOI || i.dir == dirEOI { + i.dir = dirForward + } + for i.Next() { + if i.tr.cmp.Compare(i.key, key) >= 0 { + return true + } + } + return false +} + +func (i *blockIter) Next() bool { + if i.dir == dirEOI || i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + if i.dir == dirSOI { + i.restartIndex = i.riStart + i.offset = i.offsetStart + } else if i.dir == dirBackward { + i.prevNode = i.prevNode[:0] + i.prevKeys = i.prevKeys[:0] + } + for i.offset < i.offsetRealStart { + key, value, nShared, n, err := i.block.entry(i.offset) + if err != nil { + i.sErr(i.tr.fixErrCorruptedBH(i.block.bh, err)) + return false + } + if n == 0 { + i.dir = dirEOI + return false + } + i.key = append(i.key[:nShared], key...) + i.value = value + i.offset += n + } + if i.offset >= i.offsetLimit { + i.dir = dirEOI + if i.offset != i.offsetLimit { + i.sErr(i.tr.newErrCorruptedBH(i.block.bh, "entries offset not aligned")) + } + return false + } + key, value, nShared, n, err := i.block.entry(i.offset) + if err != nil { + i.sErr(i.tr.fixErrCorruptedBH(i.block.bh, err)) + return false + } + if n == 0 { + i.dir = dirEOI + return false + } + i.key = append(i.key[:nShared], key...) + i.value = value + i.prevOffset = i.offset + i.offset += n + i.dir = dirForward + return true +} + +func (i *blockIter) Prev() bool { + if i.dir == dirSOI || i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + var ri int + if i.dir == dirForward { + // Change direction. + i.offset = i.prevOffset + if i.offset == i.offsetRealStart { + i.dir = dirSOI + return false + } + ri = i.block.restartIndex(i.restartIndex, i.riLimit, i.offset) + i.dir = dirBackward + } else if i.dir == dirEOI { + // At the end of iterator. 
+ i.restartIndex = i.riLimit + i.offset = i.offsetLimit + if i.offset == i.offsetRealStart { + i.dir = dirSOI + return false + } + ri = i.riLimit - 1 + i.dir = dirBackward + } else if len(i.prevNode) == 1 { + // This is the end of a restart range. + i.offset = i.prevNode[0] + i.prevNode = i.prevNode[:0] + if i.restartIndex == i.riStart { + i.dir = dirSOI + return false + } + i.restartIndex-- + ri = i.restartIndex + } else { + // In the middle of restart range, get from cache. + n := len(i.prevNode) - 3 + node := i.prevNode[n:] + i.prevNode = i.prevNode[:n] + // Get the key. + ko := node[0] + i.key = append(i.key[:0], i.prevKeys[ko:]...) + i.prevKeys = i.prevKeys[:ko] + // Get the value. + vo := node[1] + vl := vo + node[2] + i.value = i.block.data[vo:vl] + i.offset = vl + return true + } + // Build entries cache. + i.key = i.key[:0] + i.value = nil + offset := i.block.restartOffset(ri) + if offset == i.offset { + ri-- + if ri < 0 { + i.dir = dirSOI + return false + } + offset = i.block.restartOffset(ri) + } + i.prevNode = append(i.prevNode, offset) + for { + key, value, nShared, n, err := i.block.entry(offset) + if err != nil { + i.sErr(i.tr.fixErrCorruptedBH(i.block.bh, err)) + return false + } + if offset >= i.offsetRealStart { + if i.value != nil { + // Appends 3 variables: + // 1. Previous keys offset + // 2. Value offset in the data block + // 3. Value length + i.prevNode = append(i.prevNode, len(i.prevKeys), offset-len(i.value), len(i.value)) + i.prevKeys = append(i.prevKeys, i.key...) + } + i.value = value + } + i.key = append(i.key[:nShared], key...) + offset += n + // Stop if target offset reached. + if offset >= i.offset { + if offset != i.offset { + i.sErr(i.tr.newErrCorruptedBH(i.block.bh, "entries offset not aligned")) + return false + } + + break + } + } + i.restartIndex = ri + i.offset = offset + return true +} + +func (i *blockIter) Key() []byte { + if i.err != nil || i.dir <= dirEOI { + return nil + } + return i.key +} + +func (i *blockIter) Value() []byte { + if i.err != nil || i.dir <= dirEOI { + return nil + } + return i.value +} + +func (i *blockIter) Release() { + if i.dir != dirReleased { + i.tr = nil + i.block = nil + i.prevNode = nil + i.prevKeys = nil + i.key = nil + i.value = nil + i.dir = dirReleased + if i.blockReleaser != nil { + i.blockReleaser.Release() + i.blockReleaser = nil + } + if i.releaser != nil { + i.releaser.Release() + i.releaser = nil + } + } +} + +func (i *blockIter) SetReleaser(releaser util.Releaser) { + if i.dir == dirReleased { + panic(util.ErrReleased) + } + if i.releaser != nil && releaser != nil { + panic(util.ErrHasReleaser) + } + i.releaser = releaser +} + +func (i *blockIter) Valid() bool { + return i.err == nil && (i.dir == dirBackward || i.dir == dirForward) +} + +func (i *blockIter) Error() error { + return i.err +} + +type filterBlock struct { + bpool *util.BufferPool + data []byte + oOffset int + baseLg uint + filtersNum int +} + +func (b *filterBlock) contains(filter filter.Filter, offset uint64, key []byte) bool { + i := int(offset >> b.baseLg) + if i < b.filtersNum { + o := b.data[b.oOffset+i*4:] + n := int(binary.LittleEndian.Uint32(o)) + m := int(binary.LittleEndian.Uint32(o[4:])) + if n < m && m <= b.oOffset { + return filter.Contains(b.data[n:m], key) + } else if n == m { + return false + } + } + return true +} + +func (b *filterBlock) Release() { + b.bpool.Put(b.data) + b.bpool = nil + b.data = nil +} + +type indexIter struct { + *blockIter + tr *Reader + slice *util.Range + // Options + fillCache bool +} + +func (i 
*indexIter) Get() iterator.Iterator { + value := i.Value() + if value == nil { + return nil + } + dataBH, n := decodeBlockHandle(value) + if n == 0 { + return iterator.NewEmptyIterator(i.tr.newErrCorruptedBH(i.tr.indexBH, "bad data block handle")) + } + + var slice *util.Range + if i.slice != nil && (i.blockIter.isFirst() || i.blockIter.isLast()) { + slice = i.slice + } + return i.tr.getDataIterErr(dataBH, slice, i.tr.verifyChecksum, i.fillCache) +} + +// Reader is a table reader. +type Reader struct { + mu sync.RWMutex + fd storage.FileDesc + reader io.ReaderAt + cache *cache.NamespaceGetter + err error + bpool *util.BufferPool + // Options + o *opt.Options + cmp comparer.Comparer + filter filter.Filter + verifyChecksum bool + + dataEnd int64 + metaBH, indexBH, filterBH blockHandle + indexBlock *block + filterBlock *filterBlock +} + +func (r *Reader) blockKind(bh blockHandle) string { + switch bh.offset { + case r.metaBH.offset: + return "meta-block" + case r.indexBH.offset: + return "index-block" + case r.filterBH.offset: + if r.filterBH.length > 0 { + return "filter-block" + } + } + return "data-block" +} + +func (r *Reader) newErrCorrupted(pos, size int64, kind, reason string) error { + return &errors.ErrCorrupted{Fd: r.fd, Err: &ErrCorrupted{Pos: pos, Size: size, Kind: kind, Reason: reason}} +} + +func (r *Reader) newErrCorruptedBH(bh blockHandle, reason string) error { + return r.newErrCorrupted(int64(bh.offset), int64(bh.length), r.blockKind(bh), reason) +} + +func (r *Reader) fixErrCorruptedBH(bh blockHandle, err error) error { + if cerr, ok := err.(*ErrCorrupted); ok { + cerr.Pos = int64(bh.offset) + cerr.Size = int64(bh.length) + cerr.Kind = r.blockKind(bh) + return &errors.ErrCorrupted{Fd: r.fd, Err: cerr} + } + return err +} + +func (r *Reader) readRawBlock(bh blockHandle, verifyChecksum bool) ([]byte, error) { + data := r.bpool.Get(int(bh.length + blockTrailerLen)) + if _, err := r.reader.ReadAt(data, int64(bh.offset)); err != nil && err != io.EOF { + return nil, err + } + + if verifyChecksum { + n := bh.length + 1 + checksum0 := binary.LittleEndian.Uint32(data[n:]) + checksum1 := util.NewCRC(data[:n]).Value() + if checksum0 != checksum1 { + r.bpool.Put(data) + return nil, r.newErrCorruptedBH(bh, fmt.Sprintf("checksum mismatch, want=%#x got=%#x", checksum0, checksum1)) + } + } + + switch data[bh.length] { + case blockTypeNoCompression: + data = data[:bh.length] + case blockTypeSnappyCompression: + decLen, err := snappy.DecodedLen(data[:bh.length]) + if err != nil { + r.bpool.Put(data) + return nil, r.newErrCorruptedBH(bh, err.Error()) + } + decData := r.bpool.Get(decLen) + decData, err = snappy.Decode(decData, data[:bh.length]) + r.bpool.Put(data) + if err != nil { + r.bpool.Put(decData) + return nil, r.newErrCorruptedBH(bh, err.Error()) + } + data = decData + default: + r.bpool.Put(data) + return nil, r.newErrCorruptedBH(bh, fmt.Sprintf("unknown compression type %#x", data[bh.length])) + } + return data, nil +} + +func (r *Reader) readBlock(bh blockHandle, verifyChecksum bool) (*block, error) { + data, err := r.readRawBlock(bh, verifyChecksum) + if err != nil { + return nil, err + } + restartsLen := int(binary.LittleEndian.Uint32(data[len(data)-4:])) + b := &block{ + bpool: r.bpool, + bh: bh, + data: data, + restartsLen: restartsLen, + restartsOffset: len(data) - (restartsLen+1)*4, + } + return b, nil +} + +func (r *Reader) readBlockCached(bh blockHandle, verifyChecksum, fillCache bool) (*block, util.Releaser, error) { + if r.cache != nil { + var ( + err error + ch 
*cache.Handle + ) + if fillCache { + ch = r.cache.Get(bh.offset, func() (size int, value cache.Value) { + var b *block + b, err = r.readBlock(bh, verifyChecksum) + if err != nil { + return 0, nil + } + return cap(b.data), b + }) + } else { + ch = r.cache.Get(bh.offset, nil) + } + if ch != nil { + b, ok := ch.Value().(*block) + if !ok { + ch.Release() + return nil, nil, errors.New("leveldb/table: inconsistent block type") + } + return b, ch, err + } else if err != nil { + return nil, nil, err + } + } + + b, err := r.readBlock(bh, verifyChecksum) + return b, b, err +} + +func (r *Reader) readFilterBlock(bh blockHandle) (*filterBlock, error) { + data, err := r.readRawBlock(bh, true) + if err != nil { + return nil, err + } + n := len(data) + if n < 5 { + return nil, r.newErrCorruptedBH(bh, "too short") + } + m := n - 5 + oOffset := int(binary.LittleEndian.Uint32(data[m:])) + if oOffset > m { + return nil, r.newErrCorruptedBH(bh, "invalid data-offsets offset") + } + b := &filterBlock{ + bpool: r.bpool, + data: data, + oOffset: oOffset, + baseLg: uint(data[n-1]), + filtersNum: (m - oOffset) / 4, + } + return b, nil +} + +func (r *Reader) readFilterBlockCached(bh blockHandle, fillCache bool) (*filterBlock, util.Releaser, error) { + if r.cache != nil { + var ( + err error + ch *cache.Handle + ) + if fillCache { + ch = r.cache.Get(bh.offset, func() (size int, value cache.Value) { + var b *filterBlock + b, err = r.readFilterBlock(bh) + if err != nil { + return 0, nil + } + return cap(b.data), b + }) + } else { + ch = r.cache.Get(bh.offset, nil) + } + if ch != nil { + b, ok := ch.Value().(*filterBlock) + if !ok { + ch.Release() + return nil, nil, errors.New("leveldb/table: inconsistent block type") + } + return b, ch, err + } else if err != nil { + return nil, nil, err + } + } + + b, err := r.readFilterBlock(bh) + return b, b, err +} + +func (r *Reader) getIndexBlock(fillCache bool) (b *block, rel util.Releaser, err error) { + if r.indexBlock == nil { + return r.readBlockCached(r.indexBH, true, fillCache) + } + return r.indexBlock, util.NoopReleaser{}, nil +} + +func (r *Reader) getFilterBlock(fillCache bool) (*filterBlock, util.Releaser, error) { + if r.filterBlock == nil { + return r.readFilterBlockCached(r.filterBH, fillCache) + } + return r.filterBlock, util.NoopReleaser{}, nil +} + +func (r *Reader) newBlockIter(b *block, bReleaser util.Releaser, slice *util.Range, inclLimit bool) *blockIter { + bi := &blockIter{ + tr: r, + block: b, + blockReleaser: bReleaser, + // Valid key should never be nil. 
+ key: make([]byte, 0), + dir: dirSOI, + riStart: 0, + riLimit: b.restartsLen, + offsetStart: 0, + offsetRealStart: 0, + offsetLimit: b.restartsOffset, + } + if slice != nil { + if slice.Start != nil { + if bi.Seek(slice.Start) { + bi.riStart = b.restartIndex(bi.restartIndex, b.restartsLen, bi.prevOffset) + bi.offsetStart = b.restartOffset(bi.riStart) + bi.offsetRealStart = bi.prevOffset + } else { + bi.riStart = b.restartsLen + bi.offsetStart = b.restartsOffset + bi.offsetRealStart = b.restartsOffset + } + } + if slice.Limit != nil { + if bi.Seek(slice.Limit) && (!inclLimit || bi.Next()) { + bi.offsetLimit = bi.prevOffset + bi.riLimit = bi.restartIndex + 1 + } + } + bi.reset() + if bi.offsetStart > bi.offsetLimit { + bi.sErr(errors.New("leveldb/table: invalid slice range")) + } + } + return bi +} + +func (r *Reader) getDataIter(dataBH blockHandle, slice *util.Range, verifyChecksum, fillCache bool) iterator.Iterator { + b, rel, err := r.readBlockCached(dataBH, verifyChecksum, fillCache) + if err != nil { + return iterator.NewEmptyIterator(err) + } + return r.newBlockIter(b, rel, slice, false) +} + +func (r *Reader) getDataIterErr(dataBH blockHandle, slice *util.Range, verifyChecksum, fillCache bool) iterator.Iterator { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.err != nil { + return iterator.NewEmptyIterator(r.err) + } + + return r.getDataIter(dataBH, slice, verifyChecksum, fillCache) +} + +// NewIterator creates an iterator from the table. +// +// Slice allows slicing the iterator to only contains keys in the given +// range. A nil Range.Start is treated as a key before all keys in the +// table. And a nil Range.Limit is treated as a key after all keys in +// the table. +// +// WARNING: Any slice returned by interator (e.g. slice returned by calling +// Iterator.Key() or Iterator.Key() methods), its content should not be modified +// unless noted otherwise. +// +// The returned iterator is not safe for concurrent use and should be released +// after use. +// +// Also read Iterator documentation of the leveldb/iterator package. +func (r *Reader) NewIterator(slice *util.Range, ro *opt.ReadOptions) iterator.Iterator { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.err != nil { + return iterator.NewEmptyIterator(r.err) + } + + fillCache := !ro.GetDontFillCache() + indexBlock, rel, err := r.getIndexBlock(fillCache) + if err != nil { + return iterator.NewEmptyIterator(err) + } + index := &indexIter{ + blockIter: r.newBlockIter(indexBlock, rel, slice, true), + tr: r, + slice: slice, + fillCache: !ro.GetDontFillCache(), + } + return iterator.NewIndexedIterator(index, opt.GetStrict(r.o, ro, opt.StrictReader)) +} + +func (r *Reader) find(key []byte, filtered bool, ro *opt.ReadOptions, noValue bool) (rkey, value []byte, err error) { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.err != nil { + err = r.err + return + } + + indexBlock, rel, err := r.getIndexBlock(true) + if err != nil { + return + } + defer rel.Release() + + index := r.newBlockIter(indexBlock, nil, nil, true) + defer index.Release() + + if !index.Seek(key) { + if err = index.Error(); err == nil { + err = ErrNotFound + } + return + } + + dataBH, n := decodeBlockHandle(index.Value()) + if n == 0 { + r.err = r.newErrCorruptedBH(r.indexBH, "bad data block handle") + return nil, nil, r.err + } + + // The filter should only used for exact match. 
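+	// (A filter miss proves the key is absent from the data block, while a
+	// hit may be a false positive, so a hit still requires searching the
+	// block.)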
+ if filtered && r.filter != nil { + filterBlock, frel, ferr := r.getFilterBlock(true) + if ferr == nil { + if !filterBlock.contains(r.filter, dataBH.offset, key) { + frel.Release() + return nil, nil, ErrNotFound + } + frel.Release() + } else if !errors.IsCorrupted(ferr) { + return nil, nil, ferr + } + } + + data := r.getDataIter(dataBH, nil, r.verifyChecksum, !ro.GetDontFillCache()) + if !data.Seek(key) { + data.Release() + if err = data.Error(); err != nil { + return + } + + // The nearest greater-than key is the first key of the next block. + if !index.Next() { + if err = index.Error(); err == nil { + err = ErrNotFound + } + return + } + + dataBH, n = decodeBlockHandle(index.Value()) + if n == 0 { + r.err = r.newErrCorruptedBH(r.indexBH, "bad data block handle") + return nil, nil, r.err + } + + data = r.getDataIter(dataBH, nil, r.verifyChecksum, !ro.GetDontFillCache()) + if !data.Next() { + data.Release() + if err = data.Error(); err == nil { + err = ErrNotFound + } + return + } + } + + // Key doesn't use block buffer, no need to copy the buffer. + rkey = data.Key() + if !noValue { + if r.bpool == nil { + value = data.Value() + } else { + // Value does use block buffer, and since the buffer will be + // recycled, it need to be copied. + value = append([]byte{}, data.Value()...) + } + } + data.Release() + return +} + +// Find finds key/value pair whose key is greater than or equal to the +// given key. It returns ErrNotFound if the table doesn't contain +// such pair. +// If filtered is true then the nearest 'block' will be checked against +// 'filter data' (if present) and will immediately return ErrNotFound if +// 'filter data' indicates that such pair doesn't exist. +// +// The caller may modify the contents of the returned slice as it is its +// own copy. +// It is safe to modify the contents of the argument after Find returns. +func (r *Reader) Find(key []byte, filtered bool, ro *opt.ReadOptions) (rkey, value []byte, err error) { + return r.find(key, filtered, ro, false) +} + +// FindKey finds key that is greater than or equal to the given key. +// It returns ErrNotFound if the table doesn't contain such key. +// If filtered is true then the nearest 'block' will be checked against +// 'filter data' (if present) and will immediately return ErrNotFound if +// 'filter data' indicates that such key doesn't exist. +// +// The caller may modify the contents of the returned slice as it is its +// own copy. +// It is safe to modify the contents of the argument after Find returns. +func (r *Reader) FindKey(key []byte, filtered bool, ro *opt.ReadOptions) (rkey []byte, err error) { + rkey, _, err = r.find(key, filtered, ro, true) + return +} + +// Get gets the value for the given key. It returns errors.ErrNotFound +// if the table does not contain the key. +// +// The caller may modify the contents of the returned slice as it is its +// own copy. +// It is safe to modify the contents of the argument after Find returns. +func (r *Reader) Get(key []byte, ro *opt.ReadOptions) (value []byte, err error) { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.err != nil { + err = r.err + return + } + + rkey, value, err := r.find(key, false, ro, false) + if err == nil && r.cmp.Compare(rkey, key) != 0 { + value = nil + err = ErrNotFound + } + return +} + +// OffsetOf returns approximate offset for the given key. +// +// It is safe to modify the contents of the argument after Get returns. 
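+// The offset is measured from the start of the table file; if the key lies
+// past the last entry, the reported offset is the end of the key/value data.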
+func (r *Reader) OffsetOf(key []byte) (offset int64, err error) { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.err != nil { + err = r.err + return + } + + indexBlock, rel, err := r.readBlockCached(r.indexBH, true, true) + if err != nil { + return + } + defer rel.Release() + + index := r.newBlockIter(indexBlock, nil, nil, true) + defer index.Release() + if index.Seek(key) { + dataBH, n := decodeBlockHandle(index.Value()) + if n == 0 { + r.err = r.newErrCorruptedBH(r.indexBH, "bad data block handle") + return + } + offset = int64(dataBH.offset) + return + } + err = index.Error() + if err == nil { + offset = r.dataEnd + } + return +} + +// Release implements util.Releaser. +// It also close the file if it is an io.Closer. +func (r *Reader) Release() { + r.mu.Lock() + defer r.mu.Unlock() + + if closer, ok := r.reader.(io.Closer); ok { + closer.Close() + } + if r.indexBlock != nil { + r.indexBlock.Release() + r.indexBlock = nil + } + if r.filterBlock != nil { + r.filterBlock.Release() + r.filterBlock = nil + } + r.reader = nil + r.cache = nil + r.bpool = nil + r.err = ErrReaderReleased +} + +// NewReader creates a new initialized table reader for the file. +// The fi, cache and bpool is optional and can be nil. +// +// The returned table reader instance is safe for concurrent use. +func NewReader(f io.ReaderAt, size int64, fd storage.FileDesc, cache *cache.NamespaceGetter, bpool *util.BufferPool, o *opt.Options) (*Reader, error) { + if f == nil { + return nil, errors.New("leveldb/table: nil file") + } + + r := &Reader{ + fd: fd, + reader: f, + cache: cache, + bpool: bpool, + o: o, + cmp: o.GetComparer(), + verifyChecksum: o.GetStrict(opt.StrictBlockChecksum), + } + + if size < footerLen { + r.err = r.newErrCorrupted(0, size, "table", "too small") + return r, nil + } + + footerPos := size - footerLen + var footer [footerLen]byte + if _, err := r.reader.ReadAt(footer[:], footerPos); err != nil && err != io.EOF { + return nil, err + } + if string(footer[footerLen-len(magic):footerLen]) != magic { + r.err = r.newErrCorrupted(footerPos, footerLen, "table-footer", "bad magic number") + return r, nil + } + + var n int + // Decode the metaindex block handle. + r.metaBH, n = decodeBlockHandle(footer[:]) + if n == 0 { + r.err = r.newErrCorrupted(footerPos, footerLen, "table-footer", "bad metaindex block handle") + return r, nil + } + + // Decode the index block handle. + r.indexBH, n = decodeBlockHandle(footer[n:]) + if n == 0 { + r.err = r.newErrCorrupted(footerPos, footerLen, "table-footer", "bad index block handle") + return r, nil + } + + // Read metaindex block. + metaBlock, err := r.readBlock(r.metaBH, true) + if err != nil { + if errors.IsCorrupted(err) { + r.err = err + return r, nil + } + return nil, err + } + + // Set data end. + r.dataEnd = int64(r.metaBH.offset) + + // Read metaindex. + metaIter := r.newBlockIter(metaBlock, nil, nil, true) + for metaIter.Next() { + key := string(metaIter.Key()) + if !strings.HasPrefix(key, "filter.") { + continue + } + fn := key[7:] + if f0 := o.GetFilter(); f0 != nil && f0.Name() == fn { + r.filter = f0 + } else { + for _, f0 := range o.GetAltFilters() { + if f0.Name() == fn { + r.filter = f0 + break + } + } + } + if r.filter != nil { + filterBH, n := decodeBlockHandle(metaIter.Value()) + if n == 0 { + continue + } + r.filterBH = filterBH + // Update data end. + r.dataEnd = int64(filterBH.offset) + break + } + } + metaIter.Release() + metaBlock.Release() + + // Cache index and filter block locally, since we don't have global cache. 
+ if cache == nil { + r.indexBlock, err = r.readBlock(r.indexBH, true) + if err != nil { + if errors.IsCorrupted(err) { + r.err = err + return r, nil + } + return nil, err + } + if r.filter != nil { + r.filterBlock, err = r.readFilterBlock(r.filterBH) + if err != nil { + if !errors.IsCorrupted(err) { + return nil, err + } + + // Don't use filter then. + r.filter = nil + } + } + } + + return r, nil +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/table/table.go b/vendor/github.com/syndtr/goleveldb/leveldb/table/table.go new file mode 100644 index 0000000..beacdc1 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/table/table.go @@ -0,0 +1,177 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Package table allows read and write sorted key/value. +package table + +import ( + "encoding/binary" +) + +/* +Table: + +Table is consist of one or more data blocks, an optional filter block +a metaindex block, an index block and a table footer. Metaindex block +is a special block used to keep parameters of the table, such as filter +block name and its block handle. Index block is a special block used to +keep record of data blocks offset and length, index block use one as +restart interval. The key used by index block are the last key of preceding +block, shorter separator of adjacent blocks or shorter successor of the +last key of the last block. Filter block is an optional block contains +sequence of filter data generated by a filter generator. + +Table data structure: + + optional + / + +--------------+--------------+--------------+------+-------+-----------------+-------------+--------+ + | data block 1 | ... | data block n | filter block | metaindex block | index block | footer | + +--------------+--------------+--------------+--------------+-----------------+-------------+--------+ + + Each block followed by a 5-bytes trailer contains compression type and checksum. + +Table block trailer: + + +---------------------------+-------------------+ + | compression type (1-byte) | checksum (4-byte) | + +---------------------------+-------------------+ + + The checksum is a CRC-32 computed using Castagnoli's polynomial. Compression + type also included in the checksum. + +Table footer: + + +------------------- 40-bytes -------------------+ + / \ + +------------------------+--------------------+------+-----------------+ + | metaindex block handle / index block handle / ---- | magic (8-bytes) | + +------------------------+--------------------+------+-----------------+ + + The magic are first 64-bit of SHA-1 sum of "http://code.google.com/p/leveldb/". + +NOTE: All fixed-length integer are little-endian. +*/ + +/* +Block: + +Block is consist of one or more key/value entries and a block trailer. +Block entry shares key prefix with its preceding key until a restart +point reached. A block should contains at least one restart point. +First restart point are always zero. + +Block data structure: + + + restart point + restart point (depends on restart interval) + / / + +---------------+---------------+---------------+---------------+---------+ + | block entry 1 | block entry 2 | ... 
| block entry n | trailer | + +---------------+---------------+---------------+---------------+---------+ + +Key/value entry: + + +---- key len ----+ + / \ + +-------+---------+-----------+---------+--------------------+--------------+----------------+ + | shared (varint) | not shared (varint) | value len (varint) | key (varlen) | value (varlen) | + +-----------------+---------------------+--------------------+--------------+----------------+ + + Block entry shares key prefix with its preceding key: + Conditions: + restart_interval=2 + entry one : key=deck,value=v1 + entry two : key=dock,value=v2 + entry three: key=duck,value=v3 + The entries will be encoded as follow: + + + restart point (offset=0) + restart point (offset=16) + / / + +-----+-----+-----+----------+--------+-----+-----+-----+---------+--------+-----+-----+-----+----------+--------+ + | 0 | 4 | 2 | "deck" | "v1" | 1 | 3 | 2 | "ock" | "v2" | 0 | 4 | 2 | "duck" | "v3" | + +-----+-----+-----+----------+--------+-----+-----+-----+---------+--------+-----+-----+-----+----------+--------+ + \ / \ / \ / + +----------- entry one -----------+ +----------- entry two ----------+ +---------- entry three ----------+ + + The block trailer will contains two restart points: + + +------------+-----------+--------+ + | 0 | 16 | 2 | + +------------+-----------+---+----+ + \ / \ + +-- restart points --+ + restart points length + +Block trailer: + + +-- 4-bytes --+ + / \ + +-----------------+-----------------+-----------------+------------------------------+ + | restart point 1 | .... | restart point n | restart points len (4-bytes) | + +-----------------+-----------------+-----------------+------------------------------+ + + +NOTE: All fixed-length integer are little-endian. +*/ + +/* +Filter block: + +Filter block consist of one or more filter data and a filter block trailer. +The trailer contains filter data offsets, a trailer offset and a 1-byte base Lg. + +Filter block data structure: + + + offset 1 + offset 2 + offset n + trailer offset + / / / / + +---------------+---------------+---------------+---------+ + | filter data 1 | ... | filter data n | trailer | + +---------------+---------------+---------------+---------+ + +Filter block trailer: + + +- 4-bytes -+ + / \ + +---------------+---------------+---------------+-------------------------------+------------------+ + | data 1 offset | .... | data n offset | data-offsets offset (4-bytes) | base Lg (1-byte) | + +-------------- +---------------+---------------+-------------------------------+------------------+ + + +NOTE: All fixed-length integer are little-endian. +*/ + +const ( + blockTrailerLen = 5 + footerLen = 48 + + magic = "\x57\xfb\x80\x8b\x24\x75\x47\xdb" + + // The block type gives the per-block compression format. + // These constants are part of the file format and should not be changed. 
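+	// (The type byte is written as the first byte of each 5-byte block
+	// trailer; the CRC-32C checksum that follows covers the block data
+	// plus this type byte.)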
+ blockTypeNoCompression = 0 + blockTypeSnappyCompression = 1 + + // Generate new filter every 2KB of data + filterBaseLg = 11 + filterBase = 1 << filterBaseLg +) + +type blockHandle struct { + offset, length uint64 +} + +func decodeBlockHandle(src []byte) (blockHandle, int) { + offset, n := binary.Uvarint(src) + length, m := binary.Uvarint(src[n:]) + if n == 0 || m == 0 { + return blockHandle{}, 0 + } + return blockHandle{offset, length}, n + m +} + +func encodeBlockHandle(dst []byte, b blockHandle) int { + n := binary.PutUvarint(dst, b.offset) + m := binary.PutUvarint(dst[n:], b.length) + return n + m +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/table/table_suite_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/table/table_suite_test.go new file mode 100644 index 0000000..6465da6 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/table/table_suite_test.go @@ -0,0 +1,11 @@ +package table + +import ( + "testing" + + "github.com/syndtr/goleveldb/leveldb/testutil" +) + +func TestTable(t *testing.T) { + testutil.RunSuite(t, "Table Suite") +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/table/table_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/table/table_test.go new file mode 100644 index 0000000..232efcd --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/table/table_test.go @@ -0,0 +1,123 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package table + +import ( + "bytes" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" + "github.com/syndtr/goleveldb/leveldb/testutil" + "github.com/syndtr/goleveldb/leveldb/util" +) + +type tableWrapper struct { + *Reader +} + +func (t tableWrapper) TestFind(key []byte) (rkey, rvalue []byte, err error) { + return t.Reader.Find(key, false, nil) +} + +func (t tableWrapper) TestGet(key []byte) (value []byte, err error) { + return t.Reader.Get(key, nil) +} + +func (t tableWrapper) TestNewIterator(slice *util.Range) iterator.Iterator { + return t.Reader.NewIterator(slice, nil) +} + +var _ = testutil.Defer(func() { + Describe("Table", func() { + Describe("approximate offset test", func() { + var ( + buf = &bytes.Buffer{} + o = &opt.Options{ + BlockSize: 1024, + Compression: opt.NoCompression, + } + ) + + // Building the table. 
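+			// Keys must be appended in strictly increasing order
+			// (Writer.Append enforces this). With compression disabled, a
+			// key's offset is roughly the total size of everything written
+			// before it, which is what the expectations below rely on.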
+ tw := NewWriter(buf, o) + tw.Append([]byte("k01"), []byte("hello")) + tw.Append([]byte("k02"), []byte("hello2")) + tw.Append([]byte("k03"), bytes.Repeat([]byte{'x'}, 10000)) + tw.Append([]byte("k04"), bytes.Repeat([]byte{'x'}, 200000)) + tw.Append([]byte("k05"), bytes.Repeat([]byte{'x'}, 300000)) + tw.Append([]byte("k06"), []byte("hello3")) + tw.Append([]byte("k07"), bytes.Repeat([]byte{'x'}, 100000)) + err := tw.Close() + + It("Should be able to approximate offset of a key correctly", func() { + Expect(err).ShouldNot(HaveOccurred()) + + tr, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()), storage.FileDesc{}, nil, nil, o) + Expect(err).ShouldNot(HaveOccurred()) + CheckOffset := func(key string, expect, threshold int) { + offset, err := tr.OffsetOf([]byte(key)) + Expect(err).ShouldNot(HaveOccurred()) + Expect(offset).Should(BeNumerically("~", expect, threshold), "Offset of key %q", key) + } + + CheckOffset("k0", 0, 0) + CheckOffset("k01a", 0, 0) + CheckOffset("k02", 0, 0) + CheckOffset("k03", 0, 0) + CheckOffset("k04", 10000, 1000) + CheckOffset("k04a", 210000, 1000) + CheckOffset("k05", 210000, 1000) + CheckOffset("k06", 510000, 1000) + CheckOffset("k07", 510000, 1000) + CheckOffset("xyz", 610000, 2000) + }) + }) + + Describe("read test", func() { + Build := func(kv testutil.KeyValue) testutil.DB { + o := &opt.Options{ + BlockSize: 512, + BlockRestartInterval: 3, + } + buf := &bytes.Buffer{} + + // Building the table. + tw := NewWriter(buf, o) + kv.Iterate(func(i int, key, value []byte) { + tw.Append(key, value) + }) + tw.Close() + + // Opening the table. + tr, _ := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()), storage.FileDesc{}, nil, nil, o) + return tableWrapper{tr} + } + Test := func(kv *testutil.KeyValue, body func(r *Reader)) func() { + return func() { + db := Build(*kv) + if body != nil { + body(db.(tableWrapper).Reader) + } + testutil.KeyValueTesting(nil, *kv, db, nil, nil) + } + } + + testutil.AllKeyValueTesting(nil, Build, nil, nil) + Describe("with one key per block", Test(testutil.KeyValue_Generate(nil, 9, 1, 1, 10, 512, 512), func(r *Reader) { + It("should have correct blocks number", func() { + indexBlock, err := r.readBlock(r.indexBH, true) + Expect(err).To(BeNil()) + Expect(indexBlock.restartsLen).Should(Equal(9)) + }) + })) + }) + }) +}) diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/table/writer.go b/vendor/github.com/syndtr/goleveldb/leveldb/table/writer.go new file mode 100644 index 0000000..b96b271 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/table/writer.go @@ -0,0 +1,375 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
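+
+// This file implements the table builder (Writer); the on-disk layout it
+// produces is documented in table.go.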
+ +package table + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/golang/snappy" + + "github.com/syndtr/goleveldb/leveldb/comparer" + "github.com/syndtr/goleveldb/leveldb/filter" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/util" +) + +func sharedPrefixLen(a, b []byte) int { + i, n := 0, len(a) + if n > len(b) { + n = len(b) + } + for i < n && a[i] == b[i] { + i++ + } + return i +} + +type blockWriter struct { + restartInterval int + buf util.Buffer + nEntries int + prevKey []byte + restarts []uint32 + scratch []byte +} + +func (w *blockWriter) append(key, value []byte) { + nShared := 0 + if w.nEntries%w.restartInterval == 0 { + w.restarts = append(w.restarts, uint32(w.buf.Len())) + } else { + nShared = sharedPrefixLen(w.prevKey, key) + } + n := binary.PutUvarint(w.scratch[0:], uint64(nShared)) + n += binary.PutUvarint(w.scratch[n:], uint64(len(key)-nShared)) + n += binary.PutUvarint(w.scratch[n:], uint64(len(value))) + w.buf.Write(w.scratch[:n]) + w.buf.Write(key[nShared:]) + w.buf.Write(value) + w.prevKey = append(w.prevKey[:0], key...) + w.nEntries++ +} + +func (w *blockWriter) finish() { + // Write restarts entry. + if w.nEntries == 0 { + // Must have at least one restart entry. + w.restarts = append(w.restarts, 0) + } + w.restarts = append(w.restarts, uint32(len(w.restarts))) + for _, x := range w.restarts { + buf4 := w.buf.Alloc(4) + binary.LittleEndian.PutUint32(buf4, x) + } +} + +func (w *blockWriter) reset() { + w.buf.Reset() + w.nEntries = 0 + w.restarts = w.restarts[:0] +} + +func (w *blockWriter) bytesLen() int { + restartsLen := len(w.restarts) + if restartsLen == 0 { + restartsLen = 1 + } + return w.buf.Len() + 4*restartsLen + 4 +} + +type filterWriter struct { + generator filter.FilterGenerator + buf util.Buffer + nKeys int + offsets []uint32 +} + +func (w *filterWriter) add(key []byte) { + if w.generator == nil { + return + } + w.generator.Add(key) + w.nKeys++ +} + +func (w *filterWriter) flush(offset uint64) { + if w.generator == nil { + return + } + for x := int(offset / filterBase); x > len(w.offsets); { + w.generate() + } +} + +func (w *filterWriter) finish() { + if w.generator == nil { + return + } + // Generate last keys. + + if w.nKeys > 0 { + w.generate() + } + w.offsets = append(w.offsets, uint32(w.buf.Len())) + for _, x := range w.offsets { + buf4 := w.buf.Alloc(4) + binary.LittleEndian.PutUint32(buf4, x) + } + w.buf.WriteByte(filterBaseLg) +} + +func (w *filterWriter) generate() { + // Record offset. + w.offsets = append(w.offsets, uint32(w.buf.Len())) + // Generate filters. + if w.nKeys > 0 { + w.generator.Generate(&w.buf) + w.nKeys = 0 + } +} + +// Writer is a table writer. +type Writer struct { + writer io.Writer + err error + // Options + cmp comparer.Comparer + filter filter.Filter + compression opt.Compression + blockSize int + + dataBlock blockWriter + indexBlock blockWriter + filterBlock filterWriter + pendingBH blockHandle + offset uint64 + nEntries int + // Scratch allocated enough for 5 uvarint. Block writer should not use + // first 20-bytes since it will be used to encode block handle, which + // then passed to the block writer itself. + scratch [50]byte + comparerScratch []byte + compressionScratch []byte +} + +func (w *Writer) writeBlock(buf *util.Buffer, compression opt.Compression) (bh blockHandle, err error) { + // Compress the buffer if necessary. + var b []byte + if compression == opt.SnappyCompression { + // Allocate scratch enough for compression and block trailer. 
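+		// snappy.MaxEncodedLen reports the worst-case output size, so the
+		// scratch buffer only ever grows and can be reused for every block;
+		// the extra blockTrailerLen bytes leave room for the trailer.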
+ if n := snappy.MaxEncodedLen(buf.Len()) + blockTrailerLen; len(w.compressionScratch) < n { + w.compressionScratch = make([]byte, n) + } + compressed := snappy.Encode(w.compressionScratch, buf.Bytes()) + n := len(compressed) + b = compressed[:n+blockTrailerLen] + b[n] = blockTypeSnappyCompression + } else { + tmp := buf.Alloc(blockTrailerLen) + tmp[0] = blockTypeNoCompression + b = buf.Bytes() + } + + // Calculate the checksum. + n := len(b) - 4 + checksum := util.NewCRC(b[:n]).Value() + binary.LittleEndian.PutUint32(b[n:], checksum) + + // Write the buffer to the file. + _, err = w.writer.Write(b) + if err != nil { + return + } + bh = blockHandle{w.offset, uint64(len(b) - blockTrailerLen)} + w.offset += uint64(len(b)) + return +} + +func (w *Writer) flushPendingBH(key []byte) { + if w.pendingBH.length == 0 { + return + } + var separator []byte + if len(key) == 0 { + separator = w.cmp.Successor(w.comparerScratch[:0], w.dataBlock.prevKey) + } else { + separator = w.cmp.Separator(w.comparerScratch[:0], w.dataBlock.prevKey, key) + } + if separator == nil { + separator = w.dataBlock.prevKey + } else { + w.comparerScratch = separator + } + n := encodeBlockHandle(w.scratch[:20], w.pendingBH) + // Append the block handle to the index block. + w.indexBlock.append(separator, w.scratch[:n]) + // Reset prev key of the data block. + w.dataBlock.prevKey = w.dataBlock.prevKey[:0] + // Clear pending block handle. + w.pendingBH = blockHandle{} +} + +func (w *Writer) finishBlock() error { + w.dataBlock.finish() + bh, err := w.writeBlock(&w.dataBlock.buf, w.compression) + if err != nil { + return err + } + w.pendingBH = bh + // Reset the data block. + w.dataBlock.reset() + // Flush the filter block. + w.filterBlock.flush(w.offset) + return nil +} + +// Append appends key/value pair to the table. The keys passed must +// be in increasing order. +// +// It is safe to modify the contents of the arguments after Append returns. +func (w *Writer) Append(key, value []byte) error { + if w.err != nil { + return w.err + } + if w.nEntries > 0 && w.cmp.Compare(w.dataBlock.prevKey, key) >= 0 { + w.err = fmt.Errorf("leveldb/table: Writer: keys are not in increasing order: %q, %q", w.dataBlock.prevKey, key) + return w.err + } + + w.flushPendingBH(key) + // Append key/value pair to the data block. + w.dataBlock.append(key, value) + // Add key to the filter block. + w.filterBlock.add(key) + + // Finish the data block if block size target reached. + if w.dataBlock.bytesLen() >= w.blockSize { + if err := w.finishBlock(); err != nil { + w.err = err + return w.err + } + } + w.nEntries++ + return nil +} + +// BlocksLen returns number of blocks written so far. +func (w *Writer) BlocksLen() int { + n := w.indexBlock.nEntries + if w.pendingBH.length > 0 { + // Includes the pending block. + n++ + } + return n +} + +// EntriesLen returns number of entries added so far. +func (w *Writer) EntriesLen() int { + return w.nEntries +} + +// BytesLen returns number of bytes written so far. +func (w *Writer) BytesLen() int { + return int(w.offset) +} + +// Close will finalize the table. Calling Append is not possible +// after Close, but calling BlocksLen, EntriesLen and BytesLen +// is still possible. +func (w *Writer) Close() error { + if w.err != nil { + return w.err + } + + // Write the last data block. Or empty data block if there + // aren't any data blocks at all. 
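+	// (An empty table still gets one empty data block, so that the index
+	// block and footer always describe a well-formed table.)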
+ if w.dataBlock.nEntries > 0 || w.nEntries == 0 { + if err := w.finishBlock(); err != nil { + w.err = err + return w.err + } + } + w.flushPendingBH(nil) + + // Write the filter block. + var filterBH blockHandle + w.filterBlock.finish() + if buf := &w.filterBlock.buf; buf.Len() > 0 { + filterBH, w.err = w.writeBlock(buf, opt.NoCompression) + if w.err != nil { + return w.err + } + } + + // Write the metaindex block. + if filterBH.length > 0 { + key := []byte("filter." + w.filter.Name()) + n := encodeBlockHandle(w.scratch[:20], filterBH) + w.dataBlock.append(key, w.scratch[:n]) + } + w.dataBlock.finish() + metaindexBH, err := w.writeBlock(&w.dataBlock.buf, w.compression) + if err != nil { + w.err = err + return w.err + } + + // Write the index block. + w.indexBlock.finish() + indexBH, err := w.writeBlock(&w.indexBlock.buf, w.compression) + if err != nil { + w.err = err + return w.err + } + + // Write the table footer. + footer := w.scratch[:footerLen] + for i := range footer { + footer[i] = 0 + } + n := encodeBlockHandle(footer, metaindexBH) + encodeBlockHandle(footer[n:], indexBH) + copy(footer[footerLen-len(magic):], magic) + if _, err := w.writer.Write(footer); err != nil { + w.err = err + return w.err + } + w.offset += footerLen + + w.err = errors.New("leveldb/table: writer is closed") + return nil +} + +// NewWriter creates a new initialized table writer for the file. +// +// Table writer is not safe for concurrent use. +func NewWriter(f io.Writer, o *opt.Options) *Writer { + w := &Writer{ + writer: f, + cmp: o.GetComparer(), + filter: o.GetFilter(), + compression: o.GetCompression(), + blockSize: o.GetBlockSize(), + comparerScratch: make([]byte, 0), + } + // data block + w.dataBlock.restartInterval = o.GetBlockRestartInterval() + // The first 20-bytes are used for encoding block handle. + w.dataBlock.scratch = w.scratch[20:] + // index block + w.indexBlock.restartInterval = 1 + w.indexBlock.scratch = w.scratch[20:] + // filter block + if w.filter != nil { + w.filterBlock.generator = w.filter.NewGenerator() + w.filterBlock.flush(0) + } + return w +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/table_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/table_test.go new file mode 100644 index 0000000..36a3c4e --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/table_test.go @@ -0,0 +1,159 @@ +// Copyright (c) 2019, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
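+
+// The tests below exercise tFiles.getOverlaps: level-0 tables may overlap
+// one another, so the whole level is scanned, while higher levels hold
+// disjoint sorted tables that can be searched more narrowly.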
+ +package leveldb + +import ( + "encoding/binary" + "math/rand" + "reflect" + "testing" + + "github.com/onsi/gomega" + "github.com/syndtr/goleveldb/leveldb/storage" + "github.com/syndtr/goleveldb/leveldb/testutil" +) + +func TestGetOverlaps(t *testing.T) { + gomega.RegisterTestingT(t) + stor := testutil.NewStorage() + defer stor.Close() + s, err := newSession(stor, nil) + if err != nil { + t.Fatal(err) + } + + v := newVersion(s) + v.newStaging() + + tmp := make([]byte, 4) + mik := func(i uint64, typ keyType, ukey bool) []byte { + if i == 0 { + return nil + } + binary.BigEndian.PutUint32(tmp, uint32(i)) + if ukey { + key := make([]byte, 4) + copy(key, tmp) + return key + } + return []byte(makeInternalKey(nil, tmp, 0, typ)) + } + + rec := &sessionRecord{} + for i, f := range []struct { + min uint64 + max uint64 + level int + }{ + // Overlapped level 0 files + {1, 8, 0}, + {4, 5, 0}, + {6, 10, 0}, + // Non-overlapped level 1 files + {2, 3, 1}, + {8, 10, 1}, + {13, 13, 1}, + {20, 100, 1}, + } { + rec.addTable(f.level, int64(i), 1, mik(f.min, keyTypeVal, false), mik(f.max, keyTypeVal, false)) + } + vs := v.newStaging() + vs.commit(rec) + v = vs.finish(false) + + for i, x := range []struct { + min uint64 + max uint64 + level int + expected []int64 + }{ + // Level0 cases + {0, 0, 0, []int64{2, 1, 0}}, + {1, 0, 0, []int64{2, 1, 0}}, + {0, 10, 0, []int64{2, 1, 0}}, + {2, 7, 0, []int64{2, 1, 0}}, + + // Level1 cases + {1, 1, 1, nil}, + {0, 100, 1, []int64{3, 4, 5, 6}}, + {5, 0, 1, []int64{4, 5, 6}}, + {5, 4, 1, nil}, // invalid search space + {1, 13, 1, []int64{3, 4, 5}}, + {2, 13, 1, []int64{3, 4, 5}}, + {3, 13, 1, []int64{3, 4, 5}}, + {4, 13, 1, []int64{4, 5}}, + {4, 19, 1, []int64{4, 5}}, + {4, 20, 1, []int64{4, 5, 6}}, + {4, 100, 1, []int64{4, 5, 6}}, + {4, 105, 1, []int64{4, 5, 6}}, + } { + tf := v.levels[x.level] + res := tf.getOverlaps(nil, s.icmp, mik(x.min, keyTypeSeek, true), mik(x.max, keyTypeSeek, true), x.level == 0) + + var fnums []int64 + for _, f := range res { + fnums = append(fnums, f.fd.Num) + } + if !reflect.DeepEqual(x.expected, fnums) { + t.Errorf("case %d failed, expected %v, got %v", i, x.expected, fnums) + } + } +} + +func BenchmarkGetOverlapLevel0(b *testing.B) { + benchmarkGetOverlap(b, 0, 500000) +} + +func BenchmarkGetOverlapNonLevel0(b *testing.B) { + benchmarkGetOverlap(b, 1, 500000) +} + +func benchmarkGetOverlap(b *testing.B, level int, size int) { + stor := storage.NewMemStorage() + defer stor.Close() + s, err := newSession(stor, nil) + if err != nil { + b.Fatal(err) + } + + v := newVersion(s) + v.newStaging() + + tmp := make([]byte, 4) + mik := func(i uint64, typ keyType, ukey bool) []byte { + if i == 0 { + return nil + } + binary.BigEndian.PutUint32(tmp, uint32(i)) + if ukey { + key := make([]byte, 4) + copy(key, tmp) + return key + } + return []byte(makeInternalKey(nil, tmp, 0, typ)) + } + + rec := &sessionRecord{} + for i := 1; i <= size; i++ { + min := mik(uint64(2*i), keyTypeVal, false) + max := mik(uint64(2*i+1), keyTypeVal, false) + rec.addTable(level, int64(i), 1, min, max) + } + vs := v.newStaging() + vs.commit(rec) + v = vs.finish(false) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + files := v.levels[level] + start := rand.Intn(size) + end := rand.Intn(size-start) + start + files.getOverlaps(nil, s.icmp, mik(uint64(2*start), keyTypeVal, true), mik(uint64(2*end), keyTypeVal, true), level == 0) + } +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/testutil/db.go b/vendor/github.com/syndtr/goleveldb/leveldb/testutil/db.go 
new file mode 100644 index 0000000..ec3f177 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/testutil/db.go @@ -0,0 +1,222 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package testutil + +import ( + "fmt" + "math/rand" + + . "github.com/onsi/gomega" + + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/util" +) + +type DB interface{} + +type Put interface { + TestPut(key []byte, value []byte) error +} + +type Delete interface { + TestDelete(key []byte) error +} + +type Find interface { + TestFind(key []byte) (rkey, rvalue []byte, err error) +} + +type Get interface { + TestGet(key []byte) (value []byte, err error) +} + +type Has interface { + TestHas(key []byte) (ret bool, err error) +} + +type NewIterator interface { + TestNewIterator(slice *util.Range) iterator.Iterator +} + +type DBAct int + +func (a DBAct) String() string { + switch a { + case DBNone: + return "none" + case DBPut: + return "put" + case DBOverwrite: + return "overwrite" + case DBDelete: + return "delete" + case DBDeleteNA: + return "delete_na" + } + return "unknown" +} + +const ( + DBNone DBAct = iota + DBPut + DBOverwrite + DBDelete + DBDeleteNA +) + +type DBTesting struct { + Rand *rand.Rand + DB interface { + Get + Put + Delete + } + PostFn func(t *DBTesting) + Deleted, Present KeyValue + Act, LastAct DBAct + ActKey, LastActKey []byte +} + +func (t *DBTesting) post() { + if t.PostFn != nil { + t.PostFn(t) + } +} + +func (t *DBTesting) setAct(act DBAct, key []byte) { + t.LastAct, t.Act = t.Act, act + t.LastActKey, t.ActKey = t.ActKey, key +} + +func (t *DBTesting) text() string { + return fmt.Sprintf("last action was <%v> %q, <%v> %q", t.LastAct, t.LastActKey, t.Act, t.ActKey) +} + +func (t *DBTesting) Text() string { + return "DBTesting " + t.text() +} + +func (t *DBTesting) TestPresentKV(key, value []byte) { + rvalue, err := t.DB.TestGet(key) + Expect(err).ShouldNot(HaveOccurred(), "Get on key %q, %s", key, t.text()) + Expect(rvalue).Should(Equal(value), "Value for key %q, %s", key, t.text()) +} + +func (t *DBTesting) TestAllPresent() { + t.Present.IterateShuffled(t.Rand, func(i int, key, value []byte) { + t.TestPresentKV(key, value) + }) +} + +func (t *DBTesting) TestDeletedKey(key []byte) { + _, err := t.DB.TestGet(key) + Expect(err).Should(Equal(errors.ErrNotFound), "Get on deleted key %q, %s", key, t.text()) +} + +func (t *DBTesting) TestAllDeleted() { + t.Deleted.IterateShuffled(t.Rand, func(i int, key, value []byte) { + t.TestDeletedKey(key) + }) +} + +func (t *DBTesting) TestAll() { + dn := t.Deleted.Len() + pn := t.Present.Len() + ShuffledIndex(t.Rand, dn+pn, 1, func(i int) { + if i >= dn { + key, value := t.Present.Index(i - dn) + t.TestPresentKV(key, value) + } else { + t.TestDeletedKey(t.Deleted.KeyAt(i)) + } + }) +} + +func (t *DBTesting) Put(key, value []byte) { + if new := t.Present.PutU(key, value); new { + t.setAct(DBPut, key) + } else { + t.setAct(DBOverwrite, key) + } + t.Deleted.Delete(key) + err := t.DB.TestPut(key, value) + Expect(err).ShouldNot(HaveOccurred(), t.Text()) + t.TestPresentKV(key, value) + t.post() +} + +func (t *DBTesting) PutRandom() bool { + if t.Deleted.Len() > 0 { + i := t.Rand.Intn(t.Deleted.Len()) + key, value := t.Deleted.Index(i) + t.Put(key, value) + return true + } + return false +} + +func (t *DBTesting) Delete(key []byte) { + if exist, 
value := t.Present.Delete(key); exist { + t.setAct(DBDelete, key) + t.Deleted.PutU(key, value) + } else { + t.setAct(DBDeleteNA, key) + } + err := t.DB.TestDelete(key) + Expect(err).ShouldNot(HaveOccurred(), t.Text()) + t.TestDeletedKey(key) + t.post() +} + +func (t *DBTesting) DeleteRandom() bool { + if t.Present.Len() > 0 { + i := t.Rand.Intn(t.Present.Len()) + t.Delete(t.Present.KeyAt(i)) + return true + } + return false +} + +func (t *DBTesting) RandomAct(round int) { + for i := 0; i < round; i++ { + if t.Rand.Int()%2 == 0 { + t.PutRandom() + } else { + t.DeleteRandom() + } + } +} + +func DoDBTesting(t *DBTesting) { + if t.Rand == nil { + t.Rand = NewRand() + } + + t.DeleteRandom() + t.PutRandom() + t.DeleteRandom() + t.DeleteRandom() + for i := t.Deleted.Len() / 2; i >= 0; i-- { + t.PutRandom() + } + t.RandomAct((t.Deleted.Len() + t.Present.Len()) * 10) + + // Additional iterator testing + if db, ok := t.DB.(NewIterator); ok { + iter := db.TestNewIterator(nil) + Expect(iter.Error()).NotTo(HaveOccurred()) + + it := IteratorTesting{ + KeyValue: t.Present, + Iter: iter, + } + + DoIteratorTesting(&it) + iter.Release() + } +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/testutil/ginkgo.go b/vendor/github.com/syndtr/goleveldb/leveldb/testutil/ginkgo.go new file mode 100644 index 0000000..82f3d0e --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/testutil/ginkgo.go @@ -0,0 +1,21 @@ +package testutil + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func RunSuite(t GinkgoTestingT, name string) { + RunDefer() + + SynchronizedBeforeSuite(func() []byte { + RunDefer("setup") + return nil + }, func(data []byte) {}) + SynchronizedAfterSuite(func() { + RunDefer("teardown") + }, func() {}) + + RegisterFailHandler(Fail) + RunSpecs(t, name) +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/testutil/iter.go b/vendor/github.com/syndtr/goleveldb/leveldb/testutil/iter.go new file mode 100644 index 0000000..df6d9db --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/testutil/iter.go @@ -0,0 +1,327 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package testutil + +import ( + "fmt" + "math/rand" + + . 
"github.com/onsi/gomega" + + "github.com/syndtr/goleveldb/leveldb/iterator" +) + +type IterAct int + +func (a IterAct) String() string { + switch a { + case IterNone: + return "none" + case IterFirst: + return "first" + case IterLast: + return "last" + case IterPrev: + return "prev" + case IterNext: + return "next" + case IterSeek: + return "seek" + case IterSOI: + return "soi" + case IterEOI: + return "eoi" + } + return "unknown" +} + +const ( + IterNone IterAct = iota + IterFirst + IterLast + IterPrev + IterNext + IterSeek + IterSOI + IterEOI +) + +type IteratorTesting struct { + KeyValue + Iter iterator.Iterator + Rand *rand.Rand + PostFn func(t *IteratorTesting) + Pos int + Act, LastAct IterAct + + once bool +} + +func (t *IteratorTesting) init() { + if !t.once { + t.Pos = -1 + t.once = true + } +} + +func (t *IteratorTesting) post() { + if t.PostFn != nil { + t.PostFn(t) + } +} + +func (t *IteratorTesting) setAct(act IterAct) { + t.LastAct, t.Act = t.Act, act +} + +func (t *IteratorTesting) text() string { + return fmt.Sprintf("at pos %d and last action was <%v> -> <%v>", t.Pos, t.LastAct, t.Act) +} + +func (t *IteratorTesting) Text() string { + return "IteratorTesting is " + t.text() +} + +func (t *IteratorTesting) IsFirst() bool { + t.init() + return t.Len() > 0 && t.Pos == 0 +} + +func (t *IteratorTesting) IsLast() bool { + t.init() + return t.Len() > 0 && t.Pos == t.Len()-1 +} + +func (t *IteratorTesting) TestKV() { + t.init() + key, value := t.Index(t.Pos) + Expect(t.Iter.Key()).NotTo(BeNil()) + Expect(t.Iter.Key()).Should(Equal(key), "Key is invalid, %s", t.text()) + Expect(t.Iter.Value()).Should(Equal(value), "Value for key %q, %s", key, t.text()) +} + +func (t *IteratorTesting) First() { + t.init() + t.setAct(IterFirst) + + ok := t.Iter.First() + Expect(t.Iter.Error()).ShouldNot(HaveOccurred()) + if t.Len() > 0 { + t.Pos = 0 + Expect(ok).Should(BeTrue(), t.Text()) + t.TestKV() + } else { + t.Pos = -1 + Expect(ok).ShouldNot(BeTrue(), t.Text()) + } + t.post() +} + +func (t *IteratorTesting) Last() { + t.init() + t.setAct(IterLast) + + ok := t.Iter.Last() + Expect(t.Iter.Error()).ShouldNot(HaveOccurred()) + if t.Len() > 0 { + t.Pos = t.Len() - 1 + Expect(ok).Should(BeTrue(), t.Text()) + t.TestKV() + } else { + t.Pos = 0 + Expect(ok).ShouldNot(BeTrue(), t.Text()) + } + t.post() +} + +func (t *IteratorTesting) Next() { + t.init() + t.setAct(IterNext) + + ok := t.Iter.Next() + Expect(t.Iter.Error()).ShouldNot(HaveOccurred()) + if t.Pos < t.Len()-1 { + t.Pos++ + Expect(ok).Should(BeTrue(), t.Text()) + t.TestKV() + } else { + t.Pos = t.Len() + Expect(ok).ShouldNot(BeTrue(), t.Text()) + } + t.post() +} + +func (t *IteratorTesting) Prev() { + t.init() + t.setAct(IterPrev) + + ok := t.Iter.Prev() + Expect(t.Iter.Error()).ShouldNot(HaveOccurred()) + if t.Pos > 0 { + t.Pos-- + Expect(ok).Should(BeTrue(), t.Text()) + t.TestKV() + } else { + t.Pos = -1 + Expect(ok).ShouldNot(BeTrue(), t.Text()) + } + t.post() +} + +func (t *IteratorTesting) Seek(i int) { + t.init() + t.setAct(IterSeek) + + key, _ := t.Index(i) + oldKey, _ := t.IndexOrNil(t.Pos) + + ok := t.Iter.Seek(key) + Expect(t.Iter.Error()).ShouldNot(HaveOccurred()) + Expect(ok).Should(BeTrue(), fmt.Sprintf("Seek from key %q to %q, to pos %d, %s", oldKey, key, i, t.text())) + + t.Pos = i + t.TestKV() + t.post() +} + +func (t *IteratorTesting) SeekInexact(i int) { + t.init() + t.setAct(IterSeek) + var key0 []byte + key1, _ := t.Index(i) + if i > 0 { + key0, _ = t.Index(i - 1) + } + key := BytesSeparator(key0, key1) + oldKey, _ := 
t.IndexOrNil(t.Pos) + + ok := t.Iter.Seek(key) + Expect(t.Iter.Error()).ShouldNot(HaveOccurred()) + Expect(ok).Should(BeTrue(), fmt.Sprintf("Seek from key %q to %q (%q), to pos %d, %s", oldKey, key, key1, i, t.text())) + + t.Pos = i + t.TestKV() + t.post() +} + +func (t *IteratorTesting) SeekKey(key []byte) { + t.init() + t.setAct(IterSeek) + oldKey, _ := t.IndexOrNil(t.Pos) + i := t.Search(key) + + ok := t.Iter.Seek(key) + Expect(t.Iter.Error()).ShouldNot(HaveOccurred()) + if i < t.Len() { + key_, _ := t.Index(i) + Expect(ok).Should(BeTrue(), fmt.Sprintf("Seek from key %q to %q (%q), to pos %d, %s", oldKey, key, key_, i, t.text())) + t.Pos = i + t.TestKV() + } else { + Expect(ok).ShouldNot(BeTrue(), fmt.Sprintf("Seek from key %q to %q, %s", oldKey, key, t.text())) + } + + t.Pos = i + t.post() +} + +func (t *IteratorTesting) SOI() { + t.init() + t.setAct(IterSOI) + Expect(t.Pos).Should(BeNumerically("<=", 0), t.Text()) + for i := 0; i < 3; i++ { + t.Prev() + } + t.post() +} + +func (t *IteratorTesting) EOI() { + t.init() + t.setAct(IterEOI) + Expect(t.Pos).Should(BeNumerically(">=", t.Len()-1), t.Text()) + for i := 0; i < 3; i++ { + t.Next() + } + t.post() +} + +func (t *IteratorTesting) WalkPrev(fn func(t *IteratorTesting)) { + t.init() + for old := t.Pos; t.Pos > 0; old = t.Pos { + fn(t) + Expect(t.Pos).Should(BeNumerically("<", old), t.Text()) + } +} + +func (t *IteratorTesting) WalkNext(fn func(t *IteratorTesting)) { + t.init() + for old := t.Pos; t.Pos < t.Len()-1; old = t.Pos { + fn(t) + Expect(t.Pos).Should(BeNumerically(">", old), t.Text()) + } +} + +func (t *IteratorTesting) PrevAll() { + t.WalkPrev(func(t *IteratorTesting) { + t.Prev() + }) +} + +func (t *IteratorTesting) NextAll() { + t.WalkNext(func(t *IteratorTesting) { + t.Next() + }) +} + +func DoIteratorTesting(t *IteratorTesting) { + if t.Rand == nil { + t.Rand = NewRand() + } + t.SOI() + t.NextAll() + t.First() + t.SOI() + t.NextAll() + t.EOI() + t.PrevAll() + t.Last() + t.EOI() + t.PrevAll() + t.SOI() + + t.NextAll() + t.PrevAll() + t.NextAll() + t.Last() + t.PrevAll() + t.First() + t.NextAll() + t.EOI() + + ShuffledIndex(t.Rand, t.Len(), 1, func(i int) { + t.Seek(i) + }) + + ShuffledIndex(t.Rand, t.Len(), 1, func(i int) { + t.SeekInexact(i) + }) + + ShuffledIndex(t.Rand, t.Len(), 1, func(i int) { + t.Seek(i) + if i%2 != 0 { + t.PrevAll() + t.SOI() + } else { + t.NextAll() + t.EOI() + } + }) + + for _, key := range []string{"", "foo", "bar", "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"} { + t.SeekKey([]byte(key)) + } +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/testutil/kv.go b/vendor/github.com/syndtr/goleveldb/leveldb/testutil/kv.go new file mode 100644 index 0000000..608cbf3 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/testutil/kv.go @@ -0,0 +1,352 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +package testutil + +import ( + "fmt" + "math/rand" + "sort" + "strings" + + "github.com/syndtr/goleveldb/leveldb/util" +) + +type KeyValueEntry struct { + key, value []byte +} + +type KeyValue struct { + entries []KeyValueEntry + nbytes int +} + +func (kv *KeyValue) Put(key, value []byte) { + if n := len(kv.entries); n > 0 && cmp.Compare(kv.entries[n-1].key, key) >= 0 { + panic(fmt.Sprintf("Put: keys are not in increasing order: %q, %q", kv.entries[n-1].key, key)) + } + kv.entries = append(kv.entries, KeyValueEntry{key, value}) + kv.nbytes += len(key) + len(value) +} + +func (kv *KeyValue) PutString(key, value string) { + kv.Put([]byte(key), []byte(value)) +} + +func (kv *KeyValue) PutU(key, value []byte) bool { + if i, exist := kv.Get(key); !exist { + if i < kv.Len() { + kv.entries = append(kv.entries[:i+1], kv.entries[i:]...) + kv.entries[i] = KeyValueEntry{key, value} + } else { + kv.entries = append(kv.entries, KeyValueEntry{key, value}) + } + kv.nbytes += len(key) + len(value) + return true + } else { + kv.nbytes += len(value) - len(kv.ValueAt(i)) + kv.entries[i].value = value + } + return false +} + +func (kv *KeyValue) PutUString(key, value string) bool { + return kv.PutU([]byte(key), []byte(value)) +} + +func (kv *KeyValue) Delete(key []byte) (exist bool, value []byte) { + i, exist := kv.Get(key) + if exist { + value = kv.entries[i].value + kv.DeleteIndex(i) + } + return +} + +func (kv *KeyValue) DeleteIndex(i int) bool { + if i < kv.Len() { + kv.nbytes -= len(kv.KeyAt(i)) + len(kv.ValueAt(i)) + kv.entries = append(kv.entries[:i], kv.entries[i+1:]...) + return true + } + return false +} + +func (kv KeyValue) Len() int { + return len(kv.entries) +} + +func (kv *KeyValue) Size() int { + return kv.nbytes +} + +func (kv KeyValue) KeyAt(i int) []byte { + return kv.entries[i].key +} + +func (kv KeyValue) ValueAt(i int) []byte { + return kv.entries[i].value +} + +func (kv KeyValue) Index(i int) (key, value []byte) { + if i < 0 || i >= len(kv.entries) { + panic(fmt.Sprintf("Index #%d: out of range", i)) + } + return kv.entries[i].key, kv.entries[i].value +} + +func (kv KeyValue) IndexInexact(i int) (key_, key, value []byte) { + key, value = kv.Index(i) + var key0 []byte + var key1 = kv.KeyAt(i) + if i > 0 { + key0 = kv.KeyAt(i - 1) + } + key_ = BytesSeparator(key0, key1) + return +} + +func (kv KeyValue) IndexOrNil(i int) (key, value []byte) { + if i >= 0 && i < len(kv.entries) { + return kv.entries[i].key, kv.entries[i].value + } + return nil, nil +} + +func (kv KeyValue) IndexString(i int) (key, value string) { + key_, _value := kv.Index(i) + return string(key_), string(_value) +} + +func (kv KeyValue) Search(key []byte) int { + return sort.Search(kv.Len(), func(i int) bool { + return cmp.Compare(kv.KeyAt(i), key) >= 0 + }) +} + +func (kv KeyValue) SearchString(key string) int { + return kv.Search([]byte(key)) +} + +func (kv KeyValue) Get(key []byte) (i int, exist bool) { + i = kv.Search(key) + if i < kv.Len() && cmp.Compare(kv.KeyAt(i), key) == 0 { + exist = true + } + return +} + +func (kv KeyValue) GetString(key string) (i int, exist bool) { + return kv.Get([]byte(key)) +} + +func (kv KeyValue) Iterate(fn func(i int, key, value []byte)) { + for i, x := range kv.entries { + fn(i, x.key, x.value) + } +} + +func (kv KeyValue) IterateString(fn func(i int, key, value string)) { + kv.Iterate(func(i int, key, value []byte) { + fn(i, string(key), string(value)) + }) +} + +func (kv KeyValue) IterateShuffled(rnd *rand.Rand, fn func(i int, key, value []byte)) { + ShuffledIndex(rnd, kv.Len(), 
1, func(i int) { + fn(i, kv.entries[i].key, kv.entries[i].value) + }) +} + +func (kv KeyValue) IterateShuffledString(rnd *rand.Rand, fn func(i int, key, value string)) { + kv.IterateShuffled(rnd, func(i int, key, value []byte) { + fn(i, string(key), string(value)) + }) +} + +func (kv KeyValue) IterateInexact(fn func(i int, key_, key, value []byte)) { + for i := range kv.entries { + key_, key, value := kv.IndexInexact(i) + fn(i, key_, key, value) + } +} + +func (kv KeyValue) IterateInexactString(fn func(i int, key_, key, value string)) { + kv.IterateInexact(func(i int, key_, key, value []byte) { + fn(i, string(key_), string(key), string(value)) + }) +} + +func (kv KeyValue) Clone() KeyValue { + return KeyValue{append([]KeyValueEntry{}, kv.entries...), kv.nbytes} +} + +func (kv KeyValue) Slice(start, limit int) KeyValue { + if start < 0 || limit > kv.Len() { + panic(fmt.Sprintf("Slice %d .. %d: out of range", start, limit)) + } else if limit < start { + panic(fmt.Sprintf("Slice %d .. %d: invalid range", start, limit)) + } + return KeyValue{append([]KeyValueEntry{}, kv.entries[start:limit]...), kv.nbytes} +} + +func (kv KeyValue) SliceKey(start, limit []byte) KeyValue { + start_ := 0 + limit_ := kv.Len() + if start != nil { + start_ = kv.Search(start) + } + if limit != nil { + limit_ = kv.Search(limit) + } + return kv.Slice(start_, limit_) +} + +func (kv KeyValue) SliceKeyString(start, limit string) KeyValue { + return kv.SliceKey([]byte(start), []byte(limit)) +} + +func (kv KeyValue) SliceRange(r *util.Range) KeyValue { + if r != nil { + return kv.SliceKey(r.Start, r.Limit) + } + return kv.Clone() +} + +func (kv KeyValue) Range(start, limit int) (r util.Range) { + if kv.Len() > 0 { + if start == kv.Len() { + r.Start = BytesAfter(kv.KeyAt(start - 1)) + } else { + r.Start = kv.KeyAt(start) + } + } + if limit < kv.Len() { + r.Limit = kv.KeyAt(limit) + } + return +} + +func KeyValue_EmptyKey() *KeyValue { + kv := &KeyValue{} + kv.PutString("", "v") + return kv +} + +func KeyValue_EmptyValue() *KeyValue { + kv := &KeyValue{} + kv.PutString("abc", "") + kv.PutString("abcd", "") + return kv +} + +func KeyValue_OneKeyValue() *KeyValue { + kv := &KeyValue{} + kv.PutString("abc", "v") + return kv +} + +func KeyValue_BigValue() *KeyValue { + kv := &KeyValue{} + kv.PutString("big1", strings.Repeat("1", 200000)) + return kv +} + +func KeyValue_SpecialKey() *KeyValue { + kv := &KeyValue{} + kv.PutString("\xff\xff", "v3") + return kv +} + +func KeyValue_MultipleKeyValue() *KeyValue { + kv := &KeyValue{} + kv.PutString("a", "v") + kv.PutString("aa", "v1") + kv.PutString("aaa", "v2") + kv.PutString("aaacccccccccc", "v2") + kv.PutString("aaaccccccccccd", "v3") + kv.PutString("aaaccccccccccf", "v4") + kv.PutString("aaaccccccccccfg", "v5") + kv.PutString("ab", "v6") + kv.PutString("abc", "v7") + kv.PutString("abcd", "v8") + kv.PutString("accccccccccccccc", "v9") + kv.PutString("b", "v10") + kv.PutString("bb", "v11") + kv.PutString("bc", "v12") + kv.PutString("c", "v13") + kv.PutString("c1", "v13") + kv.PutString("czzzzzzzzzzzzzz", "v14") + kv.PutString("fffffffffffffff", "v15") + kv.PutString("g11", "v15") + kv.PutString("g111", "v15") + kv.PutString("g111\xff", "v15") + kv.PutString("zz", "v16") + kv.PutString("zzzzzzz", "v16") + kv.PutString("zzzzzzzzzzzzzzzz", "v16") + return kv +} + +var keymap = []byte("012345678ABCDEFGHIJKLMNOPQRSTUVWXYabcdefghijklmnopqrstuvwxy") + +func KeyValue_Generate(rnd *rand.Rand, n, incr, minlen, maxlen, vminlen, vmaxlen int) *KeyValue { + if rnd == nil { + rnd = NewRand() + } + 
if maxlen < minlen { + panic("max len should >= min len") + } + + rrand := func(min, max int) int { + if min == max { + return max + } + return rnd.Intn(max-min) + min + } + + kv := &KeyValue{} + endC := byte(len(keymap) - incr) + gen := make([]byte, 0, maxlen) + for i := 0; i < n; i++ { + m := rrand(minlen, maxlen) + last := gen + retry: + gen = last[:m] + if k := len(last); m > k { + for j := k; j < m; j++ { + gen[j] = 0 + } + } else { + for j := m - 1; j >= 0; j-- { + c := last[j] + if c == endC { + continue + } + gen[j] = c + byte(incr) + for j++; j < m; j++ { + gen[j] = 0 + } + goto ok + } + if m < maxlen { + m++ + goto retry + } + panic(fmt.Sprintf("only able to generate %d keys out of %d keys, try increasing max len", kv.Len(), n)) + ok: + } + key := make([]byte, m) + for j := 0; j < m; j++ { + key[j] = keymap[gen[j]] + } + value := make([]byte, rrand(vminlen, vmaxlen)) + for n := copy(value, []byte(fmt.Sprintf("v%d", i))); n < len(value); n++ { + value[n] = 'x' + } + kv.Put(key, value) + } + return kv +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/testutil/kvtest.go b/vendor/github.com/syndtr/goleveldb/leveldb/testutil/kvtest.go new file mode 100644 index 0000000..f7563dc --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/testutil/kvtest.go @@ -0,0 +1,212 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package testutil + +import ( + "fmt" + "math/rand" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/util" +) + +func TestFind(db Find, kv KeyValue) { + ShuffledIndex(nil, kv.Len(), 1, func(i int) { + key_, key, value := kv.IndexInexact(i) + + // Using exact key. + rkey, rvalue, err := db.TestFind(key) + Expect(err).ShouldNot(HaveOccurred(), "Error for exact key %q", key) + Expect(rkey).Should(Equal(key), "Key") + Expect(rvalue).Should(Equal(value), "Value for exact key %q", key) + + // Using inexact key. + rkey, rvalue, err = db.TestFind(key_) + Expect(err).ShouldNot(HaveOccurred(), "Error for inexact key %q (%q)", key_, key) + Expect(rkey).Should(Equal(key), "Key for inexact key %q (%q)", key_, key) + Expect(rvalue).Should(Equal(value), "Value for inexact key %q (%q)", key_, key) + }) +} + +func TestFindAfterLast(db Find, kv KeyValue) { + var key []byte + if kv.Len() > 0 { + key_, _ := kv.Index(kv.Len() - 1) + key = BytesAfter(key_) + } + rkey, _, err := db.TestFind(key) + Expect(err).Should(HaveOccurred(), "Find for key %q yield key %q", key, rkey) + Expect(err).Should(Equal(errors.ErrNotFound)) +} + +func TestGet(db Get, kv KeyValue) { + ShuffledIndex(nil, kv.Len(), 1, func(i int) { + key_, key, value := kv.IndexInexact(i) + + // Using exact key. + rvalue, err := db.TestGet(key) + Expect(err).ShouldNot(HaveOccurred(), "Error for key %q", key) + Expect(rvalue).Should(Equal(value), "Value for key %q", key) + + // Using inexact key. + if len(key_) > 0 { + _, err = db.TestGet(key_) + Expect(err).Should(HaveOccurred(), "Error for key %q", key_) + Expect(err).Should(Equal(errors.ErrNotFound)) + } + }) +} + +func TestHas(db Has, kv KeyValue) { + ShuffledIndex(nil, kv.Len(), 1, func(i int) { + key_, key, _ := kv.IndexInexact(i) + + // Using exact key. + ret, err := db.TestHas(key) + Expect(err).ShouldNot(HaveOccurred(), "Error for key %q", key) + Expect(ret).Should(BeTrue(), "False for key %q", key) + + // Using inexact key. 
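+		// (key_ is a generated separator that sorts between the previous
+		// and the current key without matching any stored entry, so Has
+		// must report false for it.)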
+ if len(key_) > 0 { + ret, err = db.TestHas(key_) + Expect(err).ShouldNot(HaveOccurred(), "Error for key %q", key_) + Expect(ret).ShouldNot(BeTrue(), "True for key %q", key) + } + }) +} + +func TestIter(db NewIterator, r *util.Range, kv KeyValue) { + iter := db.TestNewIterator(r) + Expect(iter.Error()).ShouldNot(HaveOccurred()) + + t := IteratorTesting{ + KeyValue: kv, + Iter: iter, + } + + DoIteratorTesting(&t) + iter.Release() +} + +func KeyValueTesting(rnd *rand.Rand, kv KeyValue, p DB, setup func(KeyValue) DB, teardown func(DB)) { + if rnd == nil { + rnd = NewRand() + } + + if p == nil { + BeforeEach(func() { + p = setup(kv) + }) + if teardown != nil { + AfterEach(func() { + teardown(p) + }) + } + } + + It("Should find all keys with Find", func() { + if db, ok := p.(Find); ok { + TestFind(db, kv) + } + }) + + It("Should return error if Find on key after the last", func() { + if db, ok := p.(Find); ok { + TestFindAfterLast(db, kv) + } + }) + + It("Should only find exact key with Get", func() { + if db, ok := p.(Get); ok { + TestGet(db, kv) + } + }) + + It("Should only find present key with Has", func() { + if db, ok := p.(Has); ok { + TestHas(db, kv) + } + }) + + It("Should iterates and seeks correctly", func(done Done) { + if db, ok := p.(NewIterator); ok { + TestIter(db, nil, kv.Clone()) + } + done <- true + }, 30.0) + + It("Should iterates and seeks slice correctly", func(done Done) { + if db, ok := p.(NewIterator); ok { + RandomIndex(rnd, kv.Len(), Min(kv.Len(), 50), func(i int) { + type slice struct { + r *util.Range + start, limit int + } + + key_, _, _ := kv.IndexInexact(i) + for _, x := range []slice{ + {&util.Range{Start: key_, Limit: nil}, i, kv.Len()}, + {&util.Range{Start: nil, Limit: key_}, 0, i}, + } { + By(fmt.Sprintf("Random index of %d .. %d", x.start, x.limit), func() { + TestIter(db, x.r, kv.Slice(x.start, x.limit)) + }) + } + }) + } + done <- true + }, 200.0) + + It("Should iterates and seeks slice correctly", func(done Done) { + if db, ok := p.(NewIterator); ok { + RandomRange(rnd, kv.Len(), Min(kv.Len(), 50), func(start, limit int) { + By(fmt.Sprintf("Random range of %d .. 
%d", start, limit), func() { + r := kv.Range(start, limit) + TestIter(db, &r, kv.Slice(start, limit)) + }) + }) + } + done <- true + }, 200.0) +} + +func AllKeyValueTesting(rnd *rand.Rand, body, setup func(KeyValue) DB, teardown func(DB)) { + Test := func(kv *KeyValue) func() { + return func() { + var p DB + if setup != nil { + Defer("setup", func() { + p = setup(*kv) + }) + } + if teardown != nil { + Defer("teardown", func() { + teardown(p) + }) + } + if body != nil { + p = body(*kv) + } + KeyValueTesting(rnd, *kv, p, func(KeyValue) DB { + return p + }, nil) + } + } + + Describe("with no key/value (empty)", Test(&KeyValue{})) + Describe("with empty key", Test(KeyValue_EmptyKey())) + Describe("with empty value", Test(KeyValue_EmptyValue())) + Describe("with one key/value", Test(KeyValue_OneKeyValue())) + Describe("with big value", Test(KeyValue_BigValue())) + Describe("with special key", Test(KeyValue_SpecialKey())) + Describe("with multiple key/value", Test(KeyValue_MultipleKeyValue())) + Describe("with generated key/value 2-incr", Test(KeyValue_Generate(nil, 120, 2, 1, 50, 10, 120))) + Describe("with generated key/value 3-incr", Test(KeyValue_Generate(nil, 120, 3, 1, 50, 10, 120))) +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/testutil/storage.go b/vendor/github.com/syndtr/goleveldb/leveldb/testutil/storage.go new file mode 100644 index 0000000..e322d04 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/testutil/storage.go @@ -0,0 +1,696 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package testutil + +import ( + "bytes" + "fmt" + "io" + "math/rand" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + + . 
"github.com/onsi/gomega" + + "github.com/syndtr/goleveldb/leveldb/storage" +) + +var ( + storageMu sync.Mutex + storageUseFS = true + storageKeepFS = false + storageNum int +) + +type StorageMode int + +const ( + ModeOpen StorageMode = 1 << iota + ModeCreate + ModeRemove + ModeRename + ModeRead + ModeWrite + ModeSync + ModeClose +) + +const ( + modeOpen = iota + modeCreate + modeRemove + modeRename + modeRead + modeWrite + modeSync + modeClose + + modeCount +) + +const ( + typeManifest = iota + typeJournal + typeTable + typeTemp + + typeCount +) + +const flattenCount = modeCount * typeCount + +func flattenType(m StorageMode, t storage.FileType) int { + var x int + switch m { + case ModeOpen: + x = modeOpen + case ModeCreate: + x = modeCreate + case ModeRemove: + x = modeRemove + case ModeRename: + x = modeRename + case ModeRead: + x = modeRead + case ModeWrite: + x = modeWrite + case ModeSync: + x = modeSync + case ModeClose: + x = modeClose + default: + panic("invalid storage mode") + } + x *= typeCount + switch t { + case storage.TypeManifest: + return x + typeManifest + case storage.TypeJournal: + return x + typeJournal + case storage.TypeTable: + return x + typeTable + case storage.TypeTemp: + return x + typeTemp + default: + panic("invalid file type") + } +} + +func listFlattenType(m StorageMode, t storage.FileType) []int { + ret := make([]int, 0, flattenCount) + add := func(x int) { + x *= typeCount + switch { + case t&storage.TypeManifest != 0: + ret = append(ret, x+typeManifest) + case t&storage.TypeJournal != 0: + ret = append(ret, x+typeJournal) + case t&storage.TypeTable != 0: + ret = append(ret, x+typeTable) + case t&storage.TypeTemp != 0: + ret = append(ret, x+typeTemp) + } + } + switch { + case m&ModeOpen != 0: + add(modeOpen) + case m&ModeCreate != 0: + add(modeCreate) + case m&ModeRemove != 0: + add(modeRemove) + case m&ModeRename != 0: + add(modeRename) + case m&ModeRead != 0: + add(modeRead) + case m&ModeWrite != 0: + add(modeWrite) + case m&ModeSync != 0: + add(modeSync) + case m&ModeClose != 0: + add(modeClose) + } + return ret +} + +func packFile(fd storage.FileDesc) uint64 { + if fd.Num>>(63-typeCount) != 0 { + panic("overflow") + } + return uint64(fd.Num<> typeCount), + } +} + +type emulatedError struct { + err error +} + +func (err emulatedError) Error() string { + return fmt.Sprintf("emulated storage error: %v", err.err) +} + +type storageLock struct { + s *Storage + l storage.Locker +} + +func (l storageLock) Unlock() { + l.l.Unlock() + l.s.logI("storage lock released") +} + +type reader struct { + s *Storage + fd storage.FileDesc + storage.Reader +} + +func (r *reader) Read(p []byte) (n int, err error) { + err = r.s.emulateError(ModeRead, r.fd.Type) + if err == nil { + r.s.stall(ModeRead, r.fd.Type) + n, err = r.Reader.Read(p) + } + r.s.count(ModeRead, r.fd.Type, n) + if err != nil && err != io.EOF { + r.s.logI("read error, fd=%s n=%d err=%v", r.fd, n, err) + } + return +} + +func (r *reader) ReadAt(p []byte, off int64) (n int, err error) { + err = r.s.emulateError(ModeRead, r.fd.Type) + if err == nil { + r.s.stall(ModeRead, r.fd.Type) + n, err = r.Reader.ReadAt(p, off) + } + r.s.count(ModeRead, r.fd.Type, n) + if err != nil && err != io.EOF { + r.s.logI("readAt error, fd=%s offset=%d n=%d err=%v", r.fd, off, n, err) + } + return +} + +func (r *reader) Close() (err error) { + return r.s.fileClose(r.fd, r.Reader) +} + +type writer struct { + s *Storage + fd storage.FileDesc + storage.Writer +} + +func (w *writer) Write(p []byte) (n int, err error) { + err = 
w.s.emulateError(ModeWrite, w.fd.Type) + if err == nil { + w.s.stall(ModeWrite, w.fd.Type) + n, err = w.Writer.Write(p) + } + w.s.count(ModeWrite, w.fd.Type, n) + if err != nil && err != io.EOF { + w.s.logI("write error, fd=%s n=%d err=%v", w.fd, n, err) + } + return +} + +func (w *writer) Sync() (err error) { + err = w.s.emulateError(ModeSync, w.fd.Type) + if err == nil { + w.s.stall(ModeSync, w.fd.Type) + err = w.Writer.Sync() + } + w.s.count(ModeSync, w.fd.Type, 0) + if err != nil { + w.s.logI("sync error, fd=%s err=%v", w.fd, err) + } + return +} + +func (w *writer) Close() (err error) { + return w.s.fileClose(w.fd, w.Writer) +} + +type Storage struct { + storage.Storage + path string + onClose func() (preserve bool, err error) + onLog func(str string) + + lmu sync.Mutex + lb bytes.Buffer + + mu sync.Mutex + rand *rand.Rand + // Open files, true=writer, false=reader + opens map[uint64]bool + counters [flattenCount]int + bytesCounter [flattenCount]int64 + emulatedError [flattenCount]error + emulatedErrorOnce [flattenCount]bool + emulatedRandomError [flattenCount]error + emulatedRandomErrorProb [flattenCount]float64 + stallCond sync.Cond + stalled [flattenCount]bool +} + +func (s *Storage) log(skip int, str string) { + s.lmu.Lock() + defer s.lmu.Unlock() + _, file, line, ok := runtime.Caller(skip + 2) + if ok { + // Truncate file name at last file name separator. + if index := strings.LastIndex(file, "/"); index >= 0 { + file = file[index+1:] + } else if index = strings.LastIndex(file, "\\"); index >= 0 { + file = file[index+1:] + } + } else { + file = "???" + line = 1 + } + fmt.Fprintf(&s.lb, "%s:%d: ", file, line) + lines := strings.Split(str, "\n") + if l := len(lines); l > 1 && lines[l-1] == "" { + lines = lines[:l-1] + } + for i, line := range lines { + if i > 0 { + s.lb.WriteString("\n\t") + } + s.lb.WriteString(line) + } + if s.onLog != nil { + s.onLog(s.lb.String()) + s.lb.Reset() + } else { + s.lb.WriteByte('\n') + } +} + +func (s *Storage) logISkip(skip int, format string, args ...interface{}) { + pc, _, _, ok := runtime.Caller(skip + 1) + if ok { + if f := runtime.FuncForPC(pc); f != nil { + fname := f.Name() + if index := strings.LastIndex(fname, "."); index >= 0 { + fname = fname[index+1:] + } + format = fname + ": " + format + } + } + s.log(skip+1, fmt.Sprintf(format, args...)) +} + +func (s *Storage) logI(format string, args ...interface{}) { + s.logISkip(1, format, args...) 
+} + +func (s *Storage) OnLog(onLog func(log string)) { + s.lmu.Lock() + s.onLog = onLog + if s.lb.Len() != 0 { + log := s.lb.String() + s.onLog(log[:len(log)-1]) + s.lb.Reset() + } + s.lmu.Unlock() +} + +func (s *Storage) Log(str string) { + s.log(1, "Log: "+str) + s.Storage.Log(str) +} + +func (s *Storage) Lock() (l storage.Locker, err error) { + l, err = s.Storage.Lock() + if err != nil { + s.logI("storage locking failed, err=%v", err) + } else { + s.logI("storage locked") + l = storageLock{s, l} + } + return +} + +func (s *Storage) List(t storage.FileType) (fds []storage.FileDesc, err error) { + fds, err = s.Storage.List(t) + if err != nil { + s.logI("list failed, err=%v", err) + return + } + s.logI("list, type=0x%x count=%d", int(t), len(fds)) + return +} + +func (s *Storage) GetMeta() (fd storage.FileDesc, err error) { + fd, err = s.Storage.GetMeta() + if err != nil { + if !os.IsNotExist(err) { + s.logI("get meta failed, err=%v", err) + } + return + } + s.logI("get meta, fd=%s", fd) + return +} + +func (s *Storage) SetMeta(fd storage.FileDesc) error { + ExpectWithOffset(1, fd.Type).To(Equal(storage.TypeManifest)) + err := s.Storage.SetMeta(fd) + if err != nil { + s.logI("set meta failed, fd=%s err=%v", fd, err) + } else { + s.logI("set meta, fd=%s", fd) + } + return err +} + +func (s *Storage) fileClose(fd storage.FileDesc, closer io.Closer) (err error) { + err = s.emulateError(ModeClose, fd.Type) + if err == nil { + s.stall(ModeClose, fd.Type) + } + x := packFile(fd) + s.mu.Lock() + defer s.mu.Unlock() + if err == nil { + ExpectWithOffset(2, s.opens).To(HaveKey(x), "File closed, fd=%s", fd) + err = closer.Close() + } + s.countNB(ModeClose, fd.Type, 0) + writer := s.opens[x] + if err != nil { + s.logISkip(1, "file close failed, fd=%s writer=%v err=%v", fd, writer, err) + } else { + s.logISkip(1, "file closed, fd=%s writer=%v", fd, writer) + delete(s.opens, x) + } + return +} + +func (s *Storage) assertOpen(fd storage.FileDesc) { + x := packFile(fd) + ExpectWithOffset(2, s.opens).NotTo(HaveKey(x), "File open, fd=%s writer=%v", fd, s.opens[x]) +} + +func (s *Storage) Open(fd storage.FileDesc) (r storage.Reader, err error) { + err = s.emulateError(ModeOpen, fd.Type) + if err == nil { + s.stall(ModeOpen, fd.Type) + } + s.mu.Lock() + defer s.mu.Unlock() + if err == nil { + s.assertOpen(fd) + s.countNB(ModeOpen, fd.Type, 0) + r, err = s.Storage.Open(fd) + } + if err != nil { + s.logI("file open failed, fd=%s err=%v", fd, err) + } else { + s.logI("file opened, fd=%s", fd) + s.opens[packFile(fd)] = false + r = &reader{s, fd, r} + } + return +} + +func (s *Storage) Create(fd storage.FileDesc) (w storage.Writer, err error) { + err = s.emulateError(ModeCreate, fd.Type) + if err == nil { + s.stall(ModeCreate, fd.Type) + } + s.mu.Lock() + defer s.mu.Unlock() + if err == nil { + s.assertOpen(fd) + s.countNB(ModeCreate, fd.Type, 0) + w, err = s.Storage.Create(fd) + } + if err != nil { + s.logI("file create failed, fd=%s err=%v", fd, err) + } else { + s.logI("file created, fd=%s", fd) + s.opens[packFile(fd)] = true + w = &writer{s, fd, w} + } + return +} + +func (s *Storage) Remove(fd storage.FileDesc) (err error) { + err = s.emulateError(ModeRemove, fd.Type) + if err == nil { + s.stall(ModeRemove, fd.Type) + } + s.mu.Lock() + defer s.mu.Unlock() + if err == nil { + s.assertOpen(fd) + s.countNB(ModeRemove, fd.Type, 0) + err = s.Storage.Remove(fd) + } + if err != nil { + s.logI("file remove failed, fd=%s err=%v", fd, err) + } else { + s.logI("file removed, fd=%s", fd) + } + return +} + +func (s 
*Storage) ForceRemove(fd storage.FileDesc) (err error) { + s.countNB(ModeRemove, fd.Type, 0) + if err = s.Storage.Remove(fd); err != nil { + s.logI("file remove failed (forced), fd=%s err=%v", fd, err) + } else { + s.logI("file removed (forced), fd=%s", fd) + } + return +} + +func (s *Storage) Rename(oldfd, newfd storage.FileDesc) (err error) { + err = s.emulateError(ModeRename, oldfd.Type) + if err == nil { + s.stall(ModeRename, oldfd.Type) + } + s.mu.Lock() + defer s.mu.Unlock() + if err == nil { + s.assertOpen(oldfd) + s.assertOpen(newfd) + s.countNB(ModeRename, oldfd.Type, 0) + err = s.Storage.Rename(oldfd, newfd) + } + if err != nil { + s.logI("file rename failed, oldfd=%s newfd=%s err=%v", oldfd, newfd, err) + } else { + s.logI("file renamed, oldfd=%s newfd=%s", oldfd, newfd) + } + return +} + +func (s *Storage) ForceRename(oldfd, newfd storage.FileDesc) (err error) { + s.countNB(ModeRename, oldfd.Type, 0) + if err = s.Storage.Rename(oldfd, newfd); err != nil { + s.logI("file rename failed (forced), oldfd=%s newfd=%s err=%v", oldfd, newfd, err) + } else { + s.logI("file renamed (forced), oldfd=%s newfd=%s", oldfd, newfd) + } + return +} + +func (s *Storage) openFiles() string { + out := "Open files:" + for x, writer := range s.opens { + fd := unpackFile(x) + out += fmt.Sprintf("\n · fd=%s writer=%v", fd, writer) + } + return out +} + +func (s *Storage) CloseCheck() { + s.mu.Lock() + defer s.mu.Unlock() + ExpectWithOffset(1, s.opens).To(BeEmpty(), s.openFiles()) +} + +func (s *Storage) OnClose(onClose func() (preserve bool, err error)) { + s.mu.Lock() + s.onClose = onClose + s.mu.Unlock() +} + +func (s *Storage) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + ExpectWithOffset(1, s.opens).To(BeEmpty(), s.openFiles()) + err := s.Storage.Close() + if err != nil { + s.logI("storage closing failed, err=%v", err) + } else { + s.logI("storage closed") + } + var preserve bool + if s.onClose != nil { + var err0 error + if preserve, err0 = s.onClose(); err0 != nil { + s.logI("onClose error, err=%v", err0) + } + } + if s.path != "" { + if storageKeepFS || preserve { + s.logI("storage is preserved, path=%v", s.path) + } else { + if err1 := os.RemoveAll(s.path); err1 != nil { + s.logI("cannot remove storage, err=%v", err1) + } else { + s.logI("storage has been removed") + } + } + } + return err +} + +func (s *Storage) countNB(m StorageMode, t storage.FileType, n int) { + s.counters[flattenType(m, t)]++ + s.bytesCounter[flattenType(m, t)] += int64(n) +} + +func (s *Storage) count(m StorageMode, t storage.FileType, n int) { + s.mu.Lock() + defer s.mu.Unlock() + s.countNB(m, t, n) +} + +func (s *Storage) ResetCounter(m StorageMode, t storage.FileType) { + for _, x := range listFlattenType(m, t) { + s.counters[x] = 0 + s.bytesCounter[x] = 0 + } +} + +func (s *Storage) Counter(m StorageMode, t storage.FileType) (count int, bytes int64) { + for _, x := range listFlattenType(m, t) { + count += s.counters[x] + bytes += s.bytesCounter[x] + } + return +} + +func (s *Storage) emulateError(m StorageMode, t storage.FileType) error { + s.mu.Lock() + defer s.mu.Unlock() + x := flattenType(m, t) + if err := s.emulatedError[x]; err != nil { + if s.emulatedErrorOnce[x] { + s.emulatedError[x] = nil + } + return emulatedError{err} + } + if err := s.emulatedRandomError[x]; err != nil && s.rand.Float64() < s.emulatedRandomErrorProb[x] { + return emulatedError{err} + } + return nil +} + +func (s *Storage) EmulateError(m StorageMode, t storage.FileType, err error) { + s.mu.Lock() + defer s.mu.Unlock() + for _, x := 
range listFlattenType(m, t) { + s.emulatedError[x] = err + s.emulatedErrorOnce[x] = false + } +} + +func (s *Storage) EmulateErrorOnce(m StorageMode, t storage.FileType, err error) { + s.mu.Lock() + defer s.mu.Unlock() + for _, x := range listFlattenType(m, t) { + s.emulatedError[x] = err + s.emulatedErrorOnce[x] = true + } +} + +func (s *Storage) EmulateRandomError(m StorageMode, t storage.FileType, prob float64, err error) { + s.mu.Lock() + defer s.mu.Unlock() + for _, x := range listFlattenType(m, t) { + s.emulatedRandomError[x] = err + s.emulatedRandomErrorProb[x] = prob + } +} + +func (s *Storage) stall(m StorageMode, t storage.FileType) { + x := flattenType(m, t) + s.mu.Lock() + defer s.mu.Unlock() + for s.stalled[x] { + s.stallCond.Wait() + } +} + +func (s *Storage) Stall(m StorageMode, t storage.FileType) { + s.mu.Lock() + defer s.mu.Unlock() + for _, x := range listFlattenType(m, t) { + s.stalled[x] = true + } +} + +func (s *Storage) Release(m StorageMode, t storage.FileType) { + s.mu.Lock() + defer s.mu.Unlock() + for _, x := range listFlattenType(m, t) { + s.stalled[x] = false + } + s.stallCond.Broadcast() +} + +func NewStorage() *Storage { + var ( + stor storage.Storage + path string + ) + if storageUseFS { + for { + storageMu.Lock() + num := storageNum + storageNum++ + storageMu.Unlock() + path = filepath.Join(os.TempDir(), fmt.Sprintf("goleveldb-test%d0%d0%d", os.Getuid(), os.Getpid(), num)) + if _, err := os.Stat(path); os.IsNotExist(err) { + stor, err = storage.OpenFile(path, false) + ExpectWithOffset(1, err).NotTo(HaveOccurred(), "creating storage at %s", path) + break + } + } + } else { + stor = storage.NewMemStorage() + } + s := &Storage{ + Storage: stor, + path: path, + rand: NewRand(), + opens: make(map[uint64]bool), + } + s.stallCond.L = &s.mu + if s.path != "" { + s.logI("using FS storage") + s.logI("storage path: %s", s.path) + } else { + s.logI("using MEM storage") + } + return s +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/testutil/util.go b/vendor/github.com/syndtr/goleveldb/leveldb/testutil/util.go new file mode 100644 index 0000000..97c5294 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/testutil/util.go @@ -0,0 +1,171 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package testutil + +import ( + "bytes" + "flag" + "math/rand" + "reflect" + "sync" + + "github.com/onsi/ginkgo/config" + + "github.com/syndtr/goleveldb/leveldb/comparer" +) + +var ( + runfn = make(map[string][]func()) + runmu sync.Mutex +) + +func Defer(args ...interface{}) bool { + var ( + group string + fn func() + ) + for _, arg := range args { + v := reflect.ValueOf(arg) + switch v.Kind() { + case reflect.String: + group = v.String() + case reflect.Func: + r := reflect.ValueOf(&fn).Elem() + r.Set(v) + } + } + if fn != nil { + runmu.Lock() + runfn[group] = append(runfn[group], fn) + runmu.Unlock() + } + return true +} + +func RunDefer(groups ...string) bool { + if len(groups) == 0 { + groups = append(groups, "") + } + runmu.Lock() + var runfn_ []func() + for _, group := range groups { + runfn_ = append(runfn_, runfn[group]...) 
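Taken together, the harness lets a test inject failures and pauses deterministically per operation mode and file type. A sketch of typical usage follows; it assumes the exported StorageMode constants (ModeSync, ModeWrite, and so on) that the methods above reference, and the test body itself is illustrative:

package leveldb_test

import (
	"errors"
	"testing"

	"github.com/onsi/gomega"

	"github.com/syndtr/goleveldb/leveldb/storage"
	"github.com/syndtr/goleveldb/leveldb/testutil"
)

func TestStorageHarnessSketch(t *testing.T) {
	gomega.RegisterTestingT(t) // the harness asserts via Gomega internally

	stor := testutil.NewStorage()
	defer stor.Close()

	// Fail the next journal sync exactly once; later syncs succeed again.
	stor.EmulateErrorOnce(testutil.ModeSync, storage.TypeJournal, errors.New("disk full"))

	// Block all manifest writes, then let them proceed.
	stor.Stall(testutil.ModeWrite, storage.TypeManifest)
	// ... exercise the DB from another goroutine here ...
	stor.Release(testutil.ModeWrite, storage.TypeManifest)

	// Counters report how often, and with how many bytes, each mode/type ran.
	count, bytes := stor.Counter(testutil.ModeWrite, storage.TypeJournal)
	t.Logf("journal writes: %d (%d bytes)", count, bytes)
}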
+ delete(runfn, group) + } + runmu.Unlock() + for _, fn := range runfn_ { + fn() + } + return runfn_ != nil +} + +func RandomSeed() int64 { + if !flag.Parsed() { + panic("random seed not initialized") + } + return config.GinkgoConfig.RandomSeed +} + +func NewRand() *rand.Rand { + return rand.New(rand.NewSource(RandomSeed())) +} + +var cmp = comparer.DefaultComparer + +func BytesSeparator(a, b []byte) []byte { + if bytes.Equal(a, b) { + return b + } + i, n := 0, len(a) + if n > len(b) { + n = len(b) + } + for ; i < n && (a[i] == b[i]); i++ { + } + x := append([]byte{}, a[:i]...) + if i < n { + if c := a[i] + 1; c < b[i] { + return append(x, c) + } + x = append(x, a[i]) + i++ + } + for ; i < len(a); i++ { + if c := a[i]; c < 0xff { + return append(x, c+1) + } else { + x = append(x, c) + } + } + if len(b) > i && b[i] > 0 { + return append(x, b[i]-1) + } + return append(x, 'x') +} + +func BytesAfter(b []byte) []byte { + var x []byte + for _, c := range b { + if c < 0xff { + return append(x, c+1) + } else { + x = append(x, c) + } + } + return append(x, 'x') +} + +func RandomIndex(rnd *rand.Rand, n, round int, fn func(i int)) { + if rnd == nil { + rnd = NewRand() + } + for x := 0; x < round; x++ { + fn(rnd.Intn(n)) + } + return +} + +func ShuffledIndex(rnd *rand.Rand, n, round int, fn func(i int)) { + if rnd == nil { + rnd = NewRand() + } + for x := 0; x < round; x++ { + for _, i := range rnd.Perm(n) { + fn(i) + } + } + return +} + +func RandomRange(rnd *rand.Rand, n, round int, fn func(start, limit int)) { + if rnd == nil { + rnd = NewRand() + } + for x := 0; x < round; x++ { + start := rnd.Intn(n) + length := 0 + if j := n - start; j > 0 { + length = rnd.Intn(j) + } + fn(start, start+length) + } + return +} + +func Max(x, y int) int { + if x > y { + return x + } + return y +} + +func Min(x, y int) int { + if x < y { + return x + } + return y +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/testutil_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/testutil_test.go new file mode 100644 index 0000000..c8cb44c --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/testutil_test.go @@ -0,0 +1,91 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + . 
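The separator helpers above generate short keys between and after existing ones, mirroring how LevelDB shortens index-block keys: roughly, BytesSeparator(a, b) returns a key k with a <= k < b when a < b, and BytesAfter(b) returns a key greater than b. A quick illustration; the outputs in the comments are what the code above produces:

package main

import (
	"fmt"

	"github.com/syndtr/goleveldb/leveldb/testutil"
)

func main() {
	// Shared prefix "ab" is kept, then the first differing byte is bumped.
	fmt.Printf("%q\n", testutil.BytesSeparator([]byte("abcd"), []byte("abzz"))) // "abd"
	// The first byte below 0xff is incremented and the rest is dropped.
	fmt.Printf("%q\n", testutil.BytesAfter([]byte("az"))) // "b"
}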
"github.com/onsi/gomega" + + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/testutil" + "github.com/syndtr/goleveldb/leveldb/util" +) + +type testingDB struct { + *DB + ro *opt.ReadOptions + wo *opt.WriteOptions + stor *testutil.Storage +} + +func (t *testingDB) TestPut(key []byte, value []byte) error { + return t.Put(key, value, t.wo) +} + +func (t *testingDB) TestDelete(key []byte) error { + return t.Delete(key, t.wo) +} + +func (t *testingDB) TestGet(key []byte) (value []byte, err error) { + return t.Get(key, t.ro) +} + +func (t *testingDB) TestHas(key []byte) (ret bool, err error) { + return t.Has(key, t.ro) +} + +func (t *testingDB) TestNewIterator(slice *util.Range) iterator.Iterator { + return t.NewIterator(slice, t.ro) +} + +func (t *testingDB) TestClose() { + err := t.Close() + ExpectWithOffset(1, err).NotTo(HaveOccurred()) + err = t.stor.Close() + ExpectWithOffset(1, err).NotTo(HaveOccurred()) +} + +func newTestingDB(o *opt.Options, ro *opt.ReadOptions, wo *opt.WriteOptions) *testingDB { + stor := testutil.NewStorage() + db, err := Open(stor, o) + // FIXME: This may be called from outside It, which may cause panic. + Expect(err).NotTo(HaveOccurred()) + return &testingDB{ + DB: db, + ro: ro, + wo: wo, + stor: stor, + } +} + +type testingTransaction struct { + *Transaction + ro *opt.ReadOptions + wo *opt.WriteOptions +} + +func (t *testingTransaction) TestPut(key []byte, value []byte) error { + return t.Put(key, value, t.wo) +} + +func (t *testingTransaction) TestDelete(key []byte) error { + return t.Delete(key, t.wo) +} + +func (t *testingTransaction) TestGet(key []byte) (value []byte, err error) { + return t.Get(key, t.ro) +} + +func (t *testingTransaction) TestHas(key []byte) (ret bool, err error) { + return t.Has(key, t.ro) +} + +func (t *testingTransaction) TestNewIterator(slice *util.Range) iterator.Iterator { + return t.NewIterator(slice, t.ro) +} + +func (t *testingTransaction) TestClose() {} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/util.go b/vendor/github.com/syndtr/goleveldb/leveldb/util.go new file mode 100644 index 0000000..0e2b519 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/util.go @@ -0,0 +1,98 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "fmt" + "sort" + + "github.com/syndtr/goleveldb/leveldb/storage" +) + +func shorten(str string) string { + if len(str) <= 8 { + return str + } + return str[:3] + ".." 
+ str[len(str)-3:] +} + +var bunits = [...]string{"", "Ki", "Mi", "Gi", "Ti"} + +func shortenb(bytes int) string { + i := 0 + for ; bytes > 1024 && i < 4; i++ { + bytes /= 1024 + } + return fmt.Sprintf("%d%sB", bytes, bunits[i]) +} + +func sshortenb(bytes int) string { + if bytes == 0 { + return "~" + } + sign := "+" + if bytes < 0 { + sign = "-" + bytes *= -1 + } + i := 0 + for ; bytes > 1024 && i < 4; i++ { + bytes /= 1024 + } + return fmt.Sprintf("%s%d%sB", sign, bytes, bunits[i]) +} + +func sint(x int) string { + if x == 0 { + return "~" + } + sign := "+" + if x < 0 { + sign = "-" + x *= -1 + } + return fmt.Sprintf("%s%d", sign, x) +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} + +func maxInt(a, b int) int { + if a > b { + return a + } + return b +} + +type fdSorter []storage.FileDesc + +func (p fdSorter) Len() int { + return len(p) +} + +func (p fdSorter) Less(i, j int) bool { + return p[i].Num < p[j].Num +} + +func (p fdSorter) Swap(i, j int) { + p[i], p[j] = p[j], p[i] +} + +func sortFds(fds []storage.FileDesc) { + sort.Sort(fdSorter(fds)) +} + +func ensureBuffer(b []byte, n int) []byte { + if cap(b) < n { + return make([]byte, n) + } + return b[:n] +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/util/buffer.go b/vendor/github.com/syndtr/goleveldb/leveldb/util/buffer.go new file mode 100644 index 0000000..21de242 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/util/buffer.go @@ -0,0 +1,293 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package util + +// This a copy of Go std bytes.Buffer with some modification +// and some features stripped. + +import ( + "bytes" + "io" +) + +// A Buffer is a variable-sized buffer of bytes with Read and Write methods. +// The zero value for Buffer is an empty buffer ready to use. +type Buffer struct { + buf []byte // contents are the bytes buf[off : len(buf)] + off int // read at &buf[off], write at &buf[len(buf)] + bootstrap [64]byte // memory to hold first slice; helps small buffers (Printf) avoid allocation. +} + +// Bytes returns a slice of the contents of the unread portion of the buffer; +// len(b.Bytes()) == b.Len(). If the caller changes the contents of the +// returned slice, the contents of the buffer will change provided there +// are no intervening method calls on the Buffer. +func (b *Buffer) Bytes() []byte { return b.buf[b.off:] } + +// String returns the contents of the unread portion of the buffer +// as a string. If the Buffer is a nil pointer, it returns "". +func (b *Buffer) String() string { + if b == nil { + // Special case, useful in debugging. + return "" + } + return string(b.buf[b.off:]) +} + +// Len returns the number of bytes of the unread portion of the buffer; +// b.Len() == len(b.Bytes()). +func (b *Buffer) Len() int { return len(b.buf) - b.off } + +// Truncate discards all but the first n unread bytes from the buffer. +// It panics if n is negative or greater than the length of the buffer. +func (b *Buffer) Truncate(n int) { + switch { + case n < 0 || n > b.Len(): + panic("leveldb/util.Buffer: truncation out of range") + case n == 0: + // Reuse buffer space. + b.off = 0 + } + b.buf = b.buf[0 : b.off+n] +} + +// Reset resets the buffer so it has no content. +// b.Reset() is the same as b.Truncate(0). +func (b *Buffer) Reset() { b.Truncate(0) } + +// grow grows the buffer to guarantee space for n more bytes. 
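shorten, shortenb and friends exist purely for compact log output. Since they are unexported in package leveldb, a standalone check has to copy the logic; this sketch reuses shortenb verbatim to show what the stats lines will contain:

package main

import "fmt"

var bunits = [...]string{"", "Ki", "Mi", "Gi", "Ti"}

// shortenb formats a byte count with binary units, same logic as above.
func shortenb(bytes int) string {
	i := 0
	for ; bytes > 1024 && i < 4; i++ {
		bytes /= 1024
	}
	return fmt.Sprintf("%d%sB", bytes, bunits[i])
}

func main() {
	fmt.Println(shortenb(512))     // 512B
	fmt.Println(shortenb(2048))    // 2KiB
	fmt.Println(shortenb(3 << 20)) // 3MiB
}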
+// It returns the index where bytes should be written. +// If the buffer can't grow it will panic with bytes.ErrTooLarge. +func (b *Buffer) grow(n int) int { + m := b.Len() + // If buffer is empty, reset to recover space. + if m == 0 && b.off != 0 { + b.Truncate(0) + } + if len(b.buf)+n > cap(b.buf) { + var buf []byte + if b.buf == nil && n <= len(b.bootstrap) { + buf = b.bootstrap[0:] + } else if m+n <= cap(b.buf)/2 { + // We can slide things down instead of allocating a new + // slice. We only need m+n <= cap(b.buf) to slide, but + // we instead let capacity get twice as large so we + // don't spend all our time copying. + copy(b.buf[:], b.buf[b.off:]) + buf = b.buf[:m] + } else { + // not enough space anywhere + buf = makeSlice(2*cap(b.buf) + n) + copy(buf, b.buf[b.off:]) + } + b.buf = buf + b.off = 0 + } + b.buf = b.buf[0 : b.off+m+n] + return b.off + m +} + +// Alloc allocs n bytes of slice from the buffer, growing the buffer as +// needed. If n is negative, Alloc will panic. +// If the buffer can't grow it will panic with bytes.ErrTooLarge. +func (b *Buffer) Alloc(n int) []byte { + if n < 0 { + panic("leveldb/util.Buffer.Alloc: negative count") + } + m := b.grow(n) + return b.buf[m:] +} + +// Grow grows the buffer's capacity, if necessary, to guarantee space for +// another n bytes. After Grow(n), at least n bytes can be written to the +// buffer without another allocation. +// If n is negative, Grow will panic. +// If the buffer can't grow it will panic with bytes.ErrTooLarge. +func (b *Buffer) Grow(n int) { + if n < 0 { + panic("leveldb/util.Buffer.Grow: negative count") + } + m := b.grow(n) + b.buf = b.buf[0:m] +} + +// Write appends the contents of p to the buffer, growing the buffer as +// needed. The return value n is the length of p; err is always nil. If the +// buffer becomes too large, Write will panic with bytes.ErrTooLarge. +func (b *Buffer) Write(p []byte) (n int, err error) { + m := b.grow(len(p)) + return copy(b.buf[m:], p), nil +} + +// MinRead is the minimum slice size passed to a Read call by +// Buffer.ReadFrom. As long as the Buffer has at least MinRead bytes beyond +// what is required to hold the contents of r, ReadFrom will not grow the +// underlying buffer. +const MinRead = 512 + +// ReadFrom reads data from r until EOF and appends it to the buffer, growing +// the buffer as needed. The return value n is the number of bytes read. Any +// error except io.EOF encountered during the read is also returned. If the +// buffer becomes too large, ReadFrom will panic with bytes.ErrTooLarge. +func (b *Buffer) ReadFrom(r io.Reader) (n int64, err error) { + // If buffer is empty, reset to recover space. + if b.off >= len(b.buf) { + b.Truncate(0) + } + for { + if free := cap(b.buf) - len(b.buf); free < MinRead { + // not enough space at end + newBuf := b.buf + if b.off+free < MinRead { + // not enough space using beginning of buffer; + // double buffer capacity + newBuf = makeSlice(2*cap(b.buf) + MinRead) + } + copy(newBuf, b.buf[b.off:]) + b.buf = newBuf[:len(b.buf)-b.off] + b.off = 0 + } + m, e := r.Read(b.buf[len(b.buf):cap(b.buf)]) + b.buf = b.buf[0 : len(b.buf)+m] + n += int64(m) + if e == io.EOF { + break + } + if e != nil { + return n, e + } + } + return n, nil // err is EOF, so return nil explicitly +} + +// makeSlice allocates a slice of size n. If the allocation fails, it panics +// with bytes.ErrTooLarge. +func makeSlice(n int) []byte { + // If the make fails, give a known error. 
+ defer func() { + if recover() != nil { + panic(bytes.ErrTooLarge) + } + }() + return make([]byte, n) +} + +// WriteTo writes data to w until the buffer is drained or an error occurs. +// The return value n is the number of bytes written; it always fits into an +// int, but it is int64 to match the io.WriterTo interface. Any error +// encountered during the write is also returned. +func (b *Buffer) WriteTo(w io.Writer) (n int64, err error) { + if b.off < len(b.buf) { + nBytes := b.Len() + m, e := w.Write(b.buf[b.off:]) + if m > nBytes { + panic("leveldb/util.Buffer.WriteTo: invalid Write count") + } + b.off += m + n = int64(m) + if e != nil { + return n, e + } + // all bytes should have been written, by definition of + // Write method in io.Writer + if m != nBytes { + return n, io.ErrShortWrite + } + } + // Buffer is now empty; reset. + b.Truncate(0) + return +} + +// WriteByte appends the byte c to the buffer, growing the buffer as needed. +// The returned error is always nil, but is included to match bufio.Writer's +// WriteByte. If the buffer becomes too large, WriteByte will panic with +// bytes.ErrTooLarge. +func (b *Buffer) WriteByte(c byte) error { + m := b.grow(1) + b.buf[m] = c + return nil +} + +// Read reads the next len(p) bytes from the buffer or until the buffer +// is drained. The return value n is the number of bytes read. If the +// buffer has no data to return, err is io.EOF (unless len(p) is zero); +// otherwise it is nil. +func (b *Buffer) Read(p []byte) (n int, err error) { + if b.off >= len(b.buf) { + // Buffer is empty, reset to recover space. + b.Truncate(0) + if len(p) == 0 { + return + } + return 0, io.EOF + } + n = copy(p, b.buf[b.off:]) + b.off += n + return +} + +// Next returns a slice containing the next n bytes from the buffer, +// advancing the buffer as if the bytes had been returned by Read. +// If there are fewer than n bytes in the buffer, Next returns the entire buffer. +// The slice is only valid until the next call to a read or write method. +func (b *Buffer) Next(n int) []byte { + m := b.Len() + if n > m { + n = m + } + data := b.buf[b.off : b.off+n] + b.off += n + return data +} + +// ReadByte reads and returns the next byte from the buffer. +// If no byte is available, it returns error io.EOF. +func (b *Buffer) ReadByte() (c byte, err error) { + if b.off >= len(b.buf) { + // Buffer is empty, reset to recover space. + b.Truncate(0) + return 0, io.EOF + } + c = b.buf[b.off] + b.off++ + return c, nil +} + +// ReadBytes reads until the first occurrence of delim in the input, +// returning a slice containing the data up to and including the delimiter. +// If ReadBytes encounters an error before finding a delimiter, +// it returns the data read before the error and the error itself (often io.EOF). +// ReadBytes returns err != nil if and only if the returned data does not end in +// delim. +func (b *Buffer) ReadBytes(delim byte) (line []byte, err error) { + slice, err := b.readSlice(delim) + // return a copy of slice. The buffer's backing array may + // be overwritten by later calls. + line = append(line, slice...) + return +} + +// readSlice is like ReadBytes but returns a reference to internal buffer data. +func (b *Buffer) readSlice(delim byte) (line []byte, err error) { + i := bytes.IndexByte(b.buf[b.off:], delim) + end := b.off + i + 1 + if i < 0 { + end = len(b.buf) + err = io.EOF + } + line = b.buf[b.off:end] + b.off = end + return line, err +} + +// NewBuffer creates and initializes a new Buffer using buf as its initial +// contents. 
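The Buffer above behaves like a trimmed-down bytes.Buffer with one addition, Alloc, which hands out writable space inside the buffer without an intermediate copy. A small usage sketch against the exported util API:

package main

import (
	"fmt"

	"github.com/syndtr/goleveldb/leveldb/util"
)

func main() {
	var b util.Buffer // the zero value is ready to use
	b.Write([]byte("hello world"))

	// Alloc grows the buffer and returns the newly allocated tail slice.
	copy(b.Alloc(1), "!")

	fmt.Println(b.Len())          // 12
	fmt.Printf("%s\n", b.Next(5)) // hello
	line, _ := b.ReadBytes('!')   // reads up to and including the delimiter
	fmt.Printf("%q\n", line)      // " world!"
}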
It is intended to prepare a Buffer to read existing data. It +// can also be used to size the internal buffer for writing. To do that, +// buf should have the desired capacity but a length of zero. +// +// In most cases, new(Buffer) (or just declaring a Buffer variable) is +// sufficient to initialize a Buffer. +func NewBuffer(buf []byte) *Buffer { return &Buffer{buf: buf} } diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/util/buffer_pool.go b/vendor/github.com/syndtr/goleveldb/leveldb/util/buffer_pool.go new file mode 100644 index 0000000..2f3db97 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/util/buffer_pool.go @@ -0,0 +1,239 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package util + +import ( + "fmt" + "sync" + "sync/atomic" + "time" +) + +type buffer struct { + b []byte + miss int +} + +// BufferPool is a 'buffer pool'. +type BufferPool struct { + pool [6]chan []byte + size [5]uint32 + sizeMiss [5]uint32 + sizeHalf [5]uint32 + baseline [4]int + baseline0 int + + mu sync.RWMutex + closed bool + closeC chan struct{} + + get uint32 + put uint32 + half uint32 + less uint32 + equal uint32 + greater uint32 + miss uint32 +} + +func (p *BufferPool) poolNum(n int) int { + if n <= p.baseline0 && n > p.baseline0/2 { + return 0 + } + for i, x := range p.baseline { + if n <= x { + return i + 1 + } + } + return len(p.baseline) + 1 +} + +// Get returns buffer with length of n. +func (p *BufferPool) Get(n int) []byte { + if p == nil { + return make([]byte, n) + } + + p.mu.RLock() + defer p.mu.RUnlock() + + if p.closed { + return make([]byte, n) + } + + atomic.AddUint32(&p.get, 1) + + poolNum := p.poolNum(n) + pool := p.pool[poolNum] + if poolNum == 0 { + // Fast path. + select { + case b := <-pool: + switch { + case cap(b) > n: + if cap(b)-n >= n { + atomic.AddUint32(&p.half, 1) + select { + case pool <- b: + default: + } + return make([]byte, n) + } else { + atomic.AddUint32(&p.less, 1) + return b[:n] + } + case cap(b) == n: + atomic.AddUint32(&p.equal, 1) + return b[:n] + default: + atomic.AddUint32(&p.greater, 1) + } + default: + atomic.AddUint32(&p.miss, 1) + } + + return make([]byte, n, p.baseline0) + } else { + sizePtr := &p.size[poolNum-1] + + select { + case b := <-pool: + switch { + case cap(b) > n: + if cap(b)-n >= n { + atomic.AddUint32(&p.half, 1) + sizeHalfPtr := &p.sizeHalf[poolNum-1] + if atomic.AddUint32(sizeHalfPtr, 1) == 20 { + atomic.StoreUint32(sizePtr, uint32(cap(b)/2)) + atomic.StoreUint32(sizeHalfPtr, 0) + } else { + select { + case pool <- b: + default: + } + } + return make([]byte, n) + } else { + atomic.AddUint32(&p.less, 1) + return b[:n] + } + case cap(b) == n: + atomic.AddUint32(&p.equal, 1) + return b[:n] + default: + atomic.AddUint32(&p.greater, 1) + if uint32(cap(b)) >= atomic.LoadUint32(sizePtr) { + select { + case pool <- b: + default: + } + } + } + default: + atomic.AddUint32(&p.miss, 1) + } + + if size := atomic.LoadUint32(sizePtr); uint32(n) > size { + if size == 0 { + atomic.CompareAndSwapUint32(sizePtr, 0, uint32(n)) + } else { + sizeMissPtr := &p.sizeMiss[poolNum-1] + if atomic.AddUint32(sizeMissPtr, 1) == 20 { + atomic.StoreUint32(sizePtr, uint32(n)) + atomic.StoreUint32(sizeMissPtr, 0) + } + } + return make([]byte, n) + } else { + return make([]byte, n, size) + } + } +} + +// Put adds given buffer to the pool. 
+func (p *BufferPool) Put(b []byte) { + if p == nil { + return + } + + p.mu.RLock() + defer p.mu.RUnlock() + + if p.closed { + return + } + + atomic.AddUint32(&p.put, 1) + + pool := p.pool[p.poolNum(cap(b))] + select { + case pool <- b: + default: + } + +} + +func (p *BufferPool) Close() { + if p == nil { + return + } + + p.mu.Lock() + if !p.closed { + p.closed = true + p.closeC <- struct{}{} + } + p.mu.Unlock() +} + +func (p *BufferPool) String() string { + if p == nil { + return "" + } + + return fmt.Sprintf("BufferPool{B·%d Z·%v Zm·%v Zh·%v G·%d P·%d H·%d <·%d =·%d >·%d M·%d}", + p.baseline0, p.size, p.sizeMiss, p.sizeHalf, p.get, p.put, p.half, p.less, p.equal, p.greater, p.miss) +} + +func (p *BufferPool) drain() { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + for _, ch := range p.pool { + select { + case <-ch: + default: + } + } + case <-p.closeC: + close(p.closeC) + for _, ch := range p.pool { + close(ch) + } + return + } + } +} + +// NewBufferPool creates a new initialized 'buffer pool'. +func NewBufferPool(baseline int) *BufferPool { + if baseline <= 0 { + panic("baseline can't be <= 0") + } + p := &BufferPool{ + baseline0: baseline, + baseline: [...]int{baseline / 4, baseline / 2, baseline * 2, baseline * 4}, + closeC: make(chan struct{}, 1), + } + for i, cap := range []int{2, 2, 4, 4, 2, 1} { + p.pool[i] = make(chan []byte, cap) + } + go p.drain() + return p +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/util/buffer_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/util/buffer_test.go new file mode 100644 index 0000000..772c2ac --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/util/buffer_test.go @@ -0,0 +1,369 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package util + +import ( + "bytes" + "io" + "math/rand" + "runtime" + "testing" +) + +const N = 10000 // make this bigger for a larger (and slower) test +var data string // test data for write tests +var testBytes []byte // test data; same as data but as a slice. + +func init() { + testBytes = make([]byte, N) + for i := 0; i < N; i++ { + testBytes[i] = 'a' + byte(i%26) + } + data = string(testBytes) +} + +// Verify that contents of buf match the string s. +func check(t *testing.T, testname string, buf *Buffer, s string) { + bytes := buf.Bytes() + str := buf.String() + if buf.Len() != len(bytes) { + t.Errorf("%s: buf.Len() == %d, len(buf.Bytes()) == %d", testname, buf.Len(), len(bytes)) + } + + if buf.Len() != len(str) { + t.Errorf("%s: buf.Len() == %d, len(buf.String()) == %d", testname, buf.Len(), len(str)) + } + + if buf.Len() != len(s) { + t.Errorf("%s: buf.Len() == %d, len(s) == %d", testname, buf.Len(), len(s)) + } + + if string(bytes) != s { + t.Errorf("%s: string(buf.Bytes()) == %q, s == %q", testname, string(bytes), s) + } +} + +// Fill buf through n writes of byte slice fub. +// The initial contents of buf corresponds to the string s; +// the result is the final contents of buf returned as a string. 
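NewBufferPool sizes several free-list buckets around a baseline capacity, and drain() slowly empties them in the background so idle buffers are eventually released. A usage sketch follows; the 32 KiB baseline is an arbitrary choice for illustration:

package main

import (
	"fmt"

	"github.com/syndtr/goleveldb/leveldb/util"
)

func main() {
	// The pool keeps size-bucketed free lists around the baseline and
	// recycles returned buffers per bucket.
	pool := util.NewBufferPool(32 * 1024)
	defer pool.Close()

	b := pool.Get(16 * 1024) // len(b) == 16384
	// ... use b ...
	pool.Put(b) // return it so a later Get can reuse the allocation

	fmt.Println(pool.String()) // get/put/hit/miss counters for debugging
}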
+func fillBytes(t *testing.T, testname string, buf *Buffer, s string, n int, fub []byte) string {
+ check(t, testname+" (fill 1)", buf, s)
+ for ; n > 0; n-- {
+ m, err := buf.Write(fub)
+ if m != len(fub) {
+ t.Errorf(testname+" (fill 2): m == %d, expected %d", m, len(fub))
+ }
+ if err != nil {
+ t.Errorf(testname+" (fill 3): err should always be nil, found err == %s", err)
+ }
+ s += string(fub)
+ check(t, testname+" (fill 4)", buf, s)
+ }
+ return s
+}
+
+func TestNewBuffer(t *testing.T) {
+ buf := NewBuffer(testBytes)
+ check(t, "NewBuffer", buf, data)
+}
+
+// Empty buf through repeated reads into fub.
+// The initial contents of buf corresponds to the string s.
+func empty(t *testing.T, testname string, buf *Buffer, s string, fub []byte) {
+ check(t, testname+" (empty 1)", buf, s)
+
+ for {
+ n, err := buf.Read(fub)
+ if n == 0 {
+ break
+ }
+ if err != nil {
+ t.Errorf(testname+" (empty 2): err should always be nil, found err == %s", err)
+ }
+ s = s[n:]
+ check(t, testname+" (empty 3)", buf, s)
+ }
+
+ check(t, testname+" (empty 4)", buf, "")
+}
+
+func TestBasicOperations(t *testing.T) {
+ var buf Buffer
+
+ for i := 0; i < 5; i++ {
+ check(t, "TestBasicOperations (1)", &buf, "")
+
+ buf.Reset()
+ check(t, "TestBasicOperations (2)", &buf, "")
+
+ buf.Truncate(0)
+ check(t, "TestBasicOperations (3)", &buf, "")
+
+ n, err := buf.Write([]byte(data[0:1]))
+ if n != 1 {
+ t.Errorf("wrote 1 byte, but n == %d", n)
+ }
+ if err != nil {
+ t.Errorf("err should always be nil, but err == %s", err)
+ }
+ check(t, "TestBasicOperations (4)", &buf, "a")
+
+ buf.WriteByte(data[1])
+ check(t, "TestBasicOperations (5)", &buf, "ab")
+
+ n, err = buf.Write([]byte(data[2:26]))
+ if n != 24 {
+ t.Errorf("wrote 25 bytes, but n == %d", n)
+ }
+ check(t, "TestBasicOperations (6)", &buf, string(data[0:26]))
+
+ buf.Truncate(26)
+ check(t, "TestBasicOperations (7)", &buf, string(data[0:26]))
+
+ buf.Truncate(20)
+ check(t, "TestBasicOperations (8)", &buf, string(data[0:20]))
+
+ empty(t, "TestBasicOperations (9)", &buf, string(data[0:20]), make([]byte, 5))
+ empty(t, "TestBasicOperations (10)", &buf, "", make([]byte, 100))
+
+ buf.WriteByte(data[1])
+ c, err := buf.ReadByte()
+ if err != nil {
+ t.Error("ReadByte unexpected eof")
+ }
+ if c != data[1] {
+ t.Errorf("ReadByte wrong value c=%v", c)
+ }
+ c, err = buf.ReadByte()
+ if err == nil {
+ t.Error("ReadByte unexpected not eof")
+ }
+ }
+}
+
+func TestLargeByteWrites(t *testing.T) {
+ var buf Buffer
+ limit := 30
+ if testing.Short() {
+ limit = 9
+ }
+ for i := 3; i < limit; i += 3 {
+ s := fillBytes(t, "TestLargeWrites (1)", &buf, "", 5, testBytes)
+ empty(t, "TestLargeByteWrites (2)", &buf, s, make([]byte, len(data)/i))
+ }
+ check(t, "TestLargeByteWrites (3)", &buf, "")
+}
+
+func TestLargeByteReads(t *testing.T) {
+ var buf Buffer
+ for i := 3; i < 30; i += 3 {
+ s := fillBytes(t, "TestLargeReads (1)", &buf, "", 5, testBytes[0:len(testBytes)/i])
+ empty(t, "TestLargeReads (2)", &buf, s, make([]byte, len(data)))
+ }
+ check(t, "TestLargeByteReads (3)", &buf, "")
+}
+
+func TestMixedReadsAndWrites(t *testing.T) {
+ var buf Buffer
+ s := ""
+ for i := 0; i < 50; i++ {
+ wlen := rand.Intn(len(data))
+ s = fillBytes(t, "TestMixedReadsAndWrites (1)", &buf, s, 1, testBytes[0:wlen])
+ rlen := rand.Intn(len(data))
+ fub := make([]byte, rlen)
+ n, _ := buf.Read(fub)
+ s = s[n:]
+ }
+ empty(t, "TestMixedReadsAndWrites (2)", &buf, s, make([]byte, buf.Len()))
+}
+
+func TestNil(t *testing.T) {
+ var b *Buffer
+ if b.String() != "" {
+ t.Errorf("expected <nil>; got
%q", b.String()) + } +} + +func TestReadFrom(t *testing.T) { + var buf Buffer + for i := 3; i < 30; i += 3 { + s := fillBytes(t, "TestReadFrom (1)", &buf, "", 5, testBytes[0:len(testBytes)/i]) + var b Buffer + b.ReadFrom(&buf) + empty(t, "TestReadFrom (2)", &b, s, make([]byte, len(data))) + } +} + +func TestWriteTo(t *testing.T) { + var buf Buffer + for i := 3; i < 30; i += 3 { + s := fillBytes(t, "TestWriteTo (1)", &buf, "", 5, testBytes[0:len(testBytes)/i]) + var b Buffer + buf.WriteTo(&b) + empty(t, "TestWriteTo (2)", &b, s, make([]byte, len(data))) + } +} + +func TestNext(t *testing.T) { + b := []byte{0, 1, 2, 3, 4} + tmp := make([]byte, 5) + for i := 0; i <= 5; i++ { + for j := i; j <= 5; j++ { + for k := 0; k <= 6; k++ { + // 0 <= i <= j <= 5; 0 <= k <= 6 + // Check that if we start with a buffer + // of length j at offset i and ask for + // Next(k), we get the right bytes. + buf := NewBuffer(b[0:j]) + n, _ := buf.Read(tmp[0:i]) + if n != i { + t.Fatalf("Read %d returned %d", i, n) + } + bb := buf.Next(k) + want := k + if want > j-i { + want = j - i + } + if len(bb) != want { + t.Fatalf("in %d,%d: len(Next(%d)) == %d", i, j, k, len(bb)) + } + for l, v := range bb { + if v != byte(l+i) { + t.Fatalf("in %d,%d: Next(%d)[%d] = %d, want %d", i, j, k, l, v, l+i) + } + } + } + } + } +} + +var readBytesTests = []struct { + buffer string + delim byte + expected []string + err error +}{ + {"", 0, []string{""}, io.EOF}, + {"a\x00", 0, []string{"a\x00"}, nil}, + {"abbbaaaba", 'b', []string{"ab", "b", "b", "aaab"}, nil}, + {"hello\x01world", 1, []string{"hello\x01"}, nil}, + {"foo\nbar", 0, []string{"foo\nbar"}, io.EOF}, + {"alpha\nbeta\ngamma\n", '\n', []string{"alpha\n", "beta\n", "gamma\n"}, nil}, + {"alpha\nbeta\ngamma", '\n', []string{"alpha\n", "beta\n", "gamma"}, io.EOF}, +} + +func TestReadBytes(t *testing.T) { + for _, test := range readBytesTests { + buf := NewBuffer([]byte(test.buffer)) + var err error + for _, expected := range test.expected { + var bytes []byte + bytes, err = buf.ReadBytes(test.delim) + if string(bytes) != expected { + t.Errorf("expected %q, got %q", expected, bytes) + } + if err != nil { + break + } + } + if err != test.err { + t.Errorf("expected error %v, got %v", test.err, err) + } + } +} + +func TestGrow(t *testing.T) { + x := []byte{'x'} + y := []byte{'y'} + tmp := make([]byte, 72) + for _, startLen := range []int{0, 100, 1000, 10000, 100000} { + xBytes := bytes.Repeat(x, startLen) + for _, growLen := range []int{0, 100, 1000, 10000, 100000} { + buf := NewBuffer(xBytes) + // If we read, this affects buf.off, which is good to test. + readBytes, _ := buf.Read(tmp) + buf.Grow(growLen) + yBytes := bytes.Repeat(y, growLen) + // Check no allocation occurs in write, as long as we're single-threaded. + var m1, m2 runtime.MemStats + runtime.ReadMemStats(&m1) + buf.Write(yBytes) + runtime.ReadMemStats(&m2) + if runtime.GOMAXPROCS(-1) == 1 && m1.Mallocs != m2.Mallocs { + t.Errorf("allocation occurred during write") + } + // Check that buffer has correct data. + if !bytes.Equal(buf.Bytes()[0:startLen-readBytes], xBytes[readBytes:]) { + t.Errorf("bad initial data at %d %d", startLen, growLen) + } + if !bytes.Equal(buf.Bytes()[startLen-readBytes:startLen-readBytes+growLen], yBytes) { + t.Errorf("bad written data at %d %d", startLen, growLen) + } + } + } +} + +// TestReadEmptyAtEOF: Was a bug: used to give EOF reading empty slice at EOF. 
+func TestReadEmptyAtEOF(t *testing.T) { + b := new(Buffer) + slice := make([]byte, 0) + n, err := b.Read(slice) + if err != nil { + t.Errorf("read error: %v", err) + } + if n != 0 { + t.Errorf("wrong count; got %d want 0", n) + } +} + +// TestBufferGrowth tests that we occasionally compact. Issue 5154. +func TestBufferGrowth(t *testing.T) { + var b Buffer + buf := make([]byte, 1024) + b.Write(buf[0:1]) + var cap0 int + for i := 0; i < 5<<10; i++ { + b.Write(buf) + b.Read(buf) + if i == 0 { + cap0 = cap(b.buf) + } + } + cap1 := cap(b.buf) + // (*Buffer).grow allows for 2x capacity slop before sliding, + // so set our error threshold at 3x. + if cap1 > cap0*3 { + t.Errorf("buffer cap = %d; too big (grew from %d)", cap1, cap0) + } +} + +// BenchmarkBufferNotEmptyWriteRead: From Issue 5154. +func BenchmarkBufferNotEmptyWriteRead(b *testing.B) { + buf := make([]byte, 1024) + for i := 0; i < b.N; i++ { + var b Buffer + b.Write(buf[0:1]) + for i := 0; i < 5<<10; i++ { + b.Write(buf) + b.Read(buf) + } + } +} + +// BenchmarkBufferFullSmallReads checks that we don't compact too often. From Issue 5154. +func BenchmarkBufferFullSmallReads(b *testing.B) { + buf := make([]byte, 1024) + for i := 0; i < b.N; i++ { + var b Buffer + b.Write(buf) + for b.Len()+20 < cap(b.buf) { + b.Write(buf[:10]) + } + for i := 0; i < 5<<10; i++ { + b.Read(buf[:1]) + b.Write(buf[:1]) + } + } +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/util/crc32.go b/vendor/github.com/syndtr/goleveldb/leveldb/util/crc32.go new file mode 100644 index 0000000..631c9d6 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/util/crc32.go @@ -0,0 +1,30 @@ +// Copyright 2011 The LevelDB-Go Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package util + +import ( + "hash/crc32" +) + +var table = crc32.MakeTable(crc32.Castagnoli) + +// CRC is a CRC-32 checksum computed using Castagnoli's polynomial. +type CRC uint32 + +// NewCRC creates a new crc based on the given bytes. +func NewCRC(b []byte) CRC { + return CRC(0).Update(b) +} + +// Update updates the crc with the given bytes. +func (c CRC) Update(b []byte) CRC { + return CRC(crc32.Update(uint32(c), table, b)) +} + +// Value returns a masked crc. +func (c CRC) Value() uint32 { + return uint32(c>>15|c<<17) + 0xa282ead8 +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/util/hash.go b/vendor/github.com/syndtr/goleveldb/leveldb/util/hash.go new file mode 100644 index 0000000..7f3fa4e --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/util/hash.go @@ -0,0 +1,48 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package util + +import ( + "encoding/binary" +) + +// Hash return hash of the given data. 
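The CRC type above wraps CRC-32C (Castagnoli), and Value applies the rotate-and-offset mask used by LevelDB's file formats so that checksums stored inside checksummed data stay well distributed. A usage sketch:

package main

import (
	"fmt"

	"github.com/syndtr/goleveldb/leveldb/util"
)

func main() {
	// Incrementally checksum two chunks, then read the masked value.
	crc := util.NewCRC([]byte("block payload"))
	crc = crc.Update([]byte(" trailer"))
	fmt.Printf("masked crc: %#08x\n", crc.Value())
}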
+func Hash(data []byte, seed uint32) uint32 { + // Similar to murmur hash + const ( + m = uint32(0xc6a4a793) + r = uint32(24) + ) + var ( + h = seed ^ (uint32(len(data)) * m) + i int + ) + + for n := len(data) - len(data)%4; i < n; i += 4 { + h += binary.LittleEndian.Uint32(data[i:]) + h *= m + h ^= (h >> 16) + } + + switch len(data) - i { + default: + panic("not reached") + case 3: + h += uint32(data[i+2]) << 16 + fallthrough + case 2: + h += uint32(data[i+1]) << 8 + fallthrough + case 1: + h += uint32(data[i]) + h *= m + h ^= (h >> r) + case 0: + } + + return h +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/util/hash_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/util/hash_test.go new file mode 100644 index 0000000..a35d273 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/util/hash_test.go @@ -0,0 +1,46 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package util + +import ( + "testing" +) + +var hashTests = []struct { + data []byte + seed uint32 + hash uint32 +}{ + {nil, 0xbc9f1d34, 0xbc9f1d34}, + {[]byte{0x62}, 0xbc9f1d34, 0xef1345c4}, + {[]byte{0xc3, 0x97}, 0xbc9f1d34, 0x5b663814}, + {[]byte{0xe2, 0x99, 0xa5}, 0xbc9f1d34, 0x323c078f}, + {[]byte{0xe1, 0x80, 0xb9, 0x32}, 0xbc9f1d34, 0xed21633a}, + {[]byte{ + 0x01, 0xc0, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x04, 0x00, + 0x00, 0x00, 0x00, 0x14, + 0x00, 0x00, 0x00, 0x18, + 0x28, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + }, 0x12345678, 0xf333dabb}, +} + +func TestHash(t *testing.T) { + for i, x := range hashTests { + h := Hash(x.data, x.seed) + if h != x.hash { + t.Fatalf("test-%d: invalid hash, %#x vs %#x", i, h, x.hash) + } + } +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/util/range.go b/vendor/github.com/syndtr/goleveldb/leveldb/util/range.go new file mode 100644 index 0000000..8515958 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/util/range.go @@ -0,0 +1,32 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package util + +// Range is a key range. +type Range struct { + // Start of the key range, include in the range. + Start []byte + + // Limit of the key range, not include in the range. + Limit []byte +} + +// BytesPrefix returns key range that satisfy the given prefix. +// This only applicable for the standard 'bytes comparer'. +func BytesPrefix(prefix []byte) *Range { + var limit []byte + for i := len(prefix) - 1; i >= 0; i-- { + c := prefix[i] + if c < 0xff { + limit = make([]byte, i+1) + copy(limit, prefix) + limit[i] = c + 1 + break + } + } + return &Range{prefix, limit} +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/util/util.go b/vendor/github.com/syndtr/goleveldb/leveldb/util/util.go new file mode 100644 index 0000000..80614af --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/util/util.go @@ -0,0 +1,73 @@ +// Copyright (c) 2013, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Package util provides utilities used throughout leveldb. 
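BytesPrefix converts a key prefix into the half-open [Start, Limit) range that iterators understand: Start is the prefix itself, and Limit is the prefix with its last incrementable byte bumped by one. A sketch; the db.NewIterator line is the typical call site and is shown only as a comment:

package main

import (
	"fmt"

	"github.com/syndtr/goleveldb/leveldb/util"
)

func main() {
	r := util.BytesPrefix([]byte("user:"))
	fmt.Printf("start=%q limit=%q\n", r.Start, r.Limit) // "user:" .. "user;"

	// Typical use: it := db.NewIterator(r, nil) to scan only "user:*" keys.
}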
+package util + +import ( + "errors" +) + +var ( + ErrReleased = errors.New("leveldb: resource already relesed") + ErrHasReleaser = errors.New("leveldb: releaser already defined") +) + +// Releaser is the interface that wraps the basic Release method. +type Releaser interface { + // Release releases associated resources. Release should always success + // and can be called multiple times without causing error. + Release() +} + +// ReleaseSetter is the interface that wraps the basic SetReleaser method. +type ReleaseSetter interface { + // SetReleaser associates the given releaser to the resources. The + // releaser will be called once coresponding resources released. + // Calling SetReleaser with nil will clear the releaser. + // + // This will panic if a releaser already present or coresponding + // resource is already released. Releaser should be cleared first + // before assigned a new one. + SetReleaser(releaser Releaser) +} + +// BasicReleaser provides basic implementation of Releaser and ReleaseSetter. +type BasicReleaser struct { + releaser Releaser + released bool +} + +// Released returns whether Release method already called. +func (r *BasicReleaser) Released() bool { + return r.released +} + +// Release implements Releaser.Release. +func (r *BasicReleaser) Release() { + if !r.released { + if r.releaser != nil { + r.releaser.Release() + r.releaser = nil + } + r.released = true + } +} + +// SetReleaser implements ReleaseSetter.SetReleaser. +func (r *BasicReleaser) SetReleaser(releaser Releaser) { + if r.released { + panic(ErrReleased) + } + if r.releaser != nil && releaser != nil { + panic(ErrHasReleaser) + } + r.releaser = releaser +} + +type NoopReleaser struct{} + +func (NoopReleaser) Release() {} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/version.go b/vendor/github.com/syndtr/goleveldb/leveldb/version.go new file mode 100644 index 0000000..9535e35 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/version.go @@ -0,0 +1,573 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "fmt" + "sync/atomic" + "time" + "unsafe" + + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/util" +) + +type tSet struct { + level int + table *tFile +} + +type version struct { + id int64 // unique monotonous increasing version id + s *session + + levels []tFiles + + // Level that should be compacted next and its compaction score. + // Score < 1 means compaction is not strictly needed. These fields + // are initialized by computeCompaction() + cLevel int + cScore float64 + + cSeek unsafe.Pointer + + closing bool + ref int + released bool +} + +// newVersion creates a new version with an unique monotonous increasing id. +func newVersion(s *session) *version { + id := atomic.AddInt64(&s.ntVersionId, 1) + nv := &version{s: s, id: id - 1} + return nv +} + +func (v *version) incref() { + if v.released { + panic("already released") + } + + v.ref++ + if v.ref == 1 { + select { + case v.s.refCh <- &vTask{vid: v.id, files: v.levels, created: time.Now()}: + // We can use v.levels directly here since it is immutable. 
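BasicReleaser gives types such as iterators an idempotent Release plus an optional releaser hook that fires once. A sketch of embedding it; the pinned type is illustrative only:

package main

import (
	"fmt"

	"github.com/syndtr/goleveldb/leveldb/util"
)

// pinned embeds BasicReleaser to inherit idempotent Release semantics.
type pinned struct {
	util.BasicReleaser
}

func main() {
	p := &pinned{}
	p.SetReleaser(util.NoopReleaser{}) // hook runs once, on first Release
	p.Release()
	p.Release()                             // safe: Release tolerates repeated calls
	fmt.Println("released:", p.Released()) // true

	// Calling SetReleaser after release panics with util.ErrReleased.
}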
+ case <-v.s.closeC: + v.s.log("reference loop already exist") + } + } +} + +func (v *version) releaseNB() { + v.ref-- + if v.ref > 0 { + return + } else if v.ref < 0 { + panic("negative version ref") + } + select { + case v.s.relCh <- &vTask{vid: v.id, files: v.levels, created: time.Now()}: + // We can use v.levels directly here since it is immutable. + case <-v.s.closeC: + v.s.log("reference loop already exist") + } + + v.released = true +} + +func (v *version) release() { + v.s.vmu.Lock() + v.releaseNB() + v.s.vmu.Unlock() +} + +func (v *version) walkOverlapping(aux tFiles, ikey internalKey, f func(level int, t *tFile) bool, lf func(level int) bool) { + ukey := ikey.ukey() + + // Aux level. + if aux != nil { + for _, t := range aux { + if t.overlaps(v.s.icmp, ukey, ukey) { + if !f(-1, t) { + return + } + } + } + + if lf != nil && !lf(-1) { + return + } + } + + // Walk tables level-by-level. + for level, tables := range v.levels { + if len(tables) == 0 { + continue + } + + if level == 0 { + // Level-0 files may overlap each other. Find all files that + // overlap ukey. + for _, t := range tables { + if t.overlaps(v.s.icmp, ukey, ukey) { + if !f(level, t) { + return + } + } + } + } else { + if i := tables.searchMax(v.s.icmp, ikey); i < len(tables) { + t := tables[i] + if v.s.icmp.uCompare(ukey, t.imin.ukey()) >= 0 { + if !f(level, t) { + return + } + } + } + } + + if lf != nil && !lf(level) { + return + } + } +} + +func (v *version) get(aux tFiles, ikey internalKey, ro *opt.ReadOptions, noValue bool) (value []byte, tcomp bool, err error) { + if v.closing { + return nil, false, ErrClosed + } + + ukey := ikey.ukey() + sampleSeeks := !v.s.o.GetDisableSeeksCompaction() + + var ( + tset *tSet + tseek bool + + // Level-0. + zfound bool + zseq uint64 + zkt keyType + zval []byte + ) + + err = ErrNotFound + + // Since entries never hop across level, finding key/value + // in smaller level make later levels irrelevant. + v.walkOverlapping(aux, ikey, func(level int, t *tFile) bool { + if sampleSeeks && level >= 0 && !tseek { + if tset == nil { + tset = &tSet{level, t} + } else { + tseek = true + } + } + + var ( + fikey, fval []byte + ferr error + ) + if noValue { + fikey, ferr = v.s.tops.findKey(t, ikey, ro) + } else { + fikey, fval, ferr = v.s.tops.find(t, ikey, ro) + } + + switch ferr { + case nil: + case ErrNotFound: + return true + default: + err = ferr + return false + } + + if fukey, fseq, fkt, fkerr := parseInternalKey(fikey); fkerr == nil { + if v.s.icmp.uCompare(ukey, fukey) == 0 { + // Level <= 0 may overlaps each-other. 
+ if level <= 0 { + if fseq >= zseq { + zfound = true + zseq = fseq + zkt = fkt + zval = fval + } + } else { + switch fkt { + case keyTypeVal: + value = fval + err = nil + case keyTypeDel: + default: + panic("leveldb: invalid internalKey type") + } + return false + } + } + } else { + err = fkerr + return false + } + + return true + }, func(level int) bool { + if zfound { + switch zkt { + case keyTypeVal: + value = zval + err = nil + case keyTypeDel: + default: + panic("leveldb: invalid internalKey type") + } + return false + } + + return true + }) + + if tseek && tset.table.consumeSeek() <= 0 { + tcomp = atomic.CompareAndSwapPointer(&v.cSeek, nil, unsafe.Pointer(tset)) + } + + return +} + +func (v *version) sampleSeek(ikey internalKey) (tcomp bool) { + var tset *tSet + + v.walkOverlapping(nil, ikey, func(level int, t *tFile) bool { + if tset == nil { + tset = &tSet{level, t} + return true + } + if tset.table.consumeSeek() <= 0 { + tcomp = atomic.CompareAndSwapPointer(&v.cSeek, nil, unsafe.Pointer(tset)) + } + return false + }, nil) + + return +} + +func (v *version) getIterators(slice *util.Range, ro *opt.ReadOptions) (its []iterator.Iterator) { + strict := opt.GetStrict(v.s.o.Options, ro, opt.StrictReader) + for level, tables := range v.levels { + if level == 0 { + // Merge all level zero files together since they may overlap. + for _, t := range tables { + its = append(its, v.s.tops.newIterator(t, slice, ro)) + } + } else if len(tables) != 0 { + its = append(its, iterator.NewIndexedIterator(tables.newIndexIterator(v.s.tops, v.s.icmp, slice, ro), strict)) + } + } + return +} + +func (v *version) newStaging() *versionStaging { + return &versionStaging{base: v} +} + +// Spawn a new version based on this version. +func (v *version) spawn(r *sessionRecord, trivial bool) *version { + staging := v.newStaging() + staging.commit(r) + return staging.finish(trivial) +} + +func (v *version) fillRecord(r *sessionRecord) { + for level, tables := range v.levels { + for _, t := range tables { + r.addTableFile(level, t) + } + } +} + +func (v *version) tLen(level int) int { + if level < len(v.levels) { + return len(v.levels[level]) + } + return 0 +} + +func (v *version) offsetOf(ikey internalKey) (n int64, err error) { + for level, tables := range v.levels { + for _, t := range tables { + if v.s.icmp.Compare(t.imax, ikey) <= 0 { + // Entire file is before "ikey", so just add the file size + n += t.size + } else if v.s.icmp.Compare(t.imin, ikey) > 0 { + // Entire file is after "ikey", so ignore + if level > 0 { + // Files other than level 0 are sorted by meta->min, so + // no further files in this level will contain data for + // "ikey". + break + } + } else { + // "ikey" falls in the range for this table. Add the + // approximate offset of "ikey" within the table. 
+ if m, err := v.s.tops.offsetOf(t, ikey); err == nil { + n += m + } else { + return 0, err + } + } + } + } + + return +} + +func (v *version) pickMemdbLevel(umin, umax []byte, maxLevel int) (level int) { + if maxLevel > 0 { + if len(v.levels) == 0 { + return maxLevel + } + if !v.levels[0].overlaps(v.s.icmp, umin, umax, true) { + var overlaps tFiles + for ; level < maxLevel; level++ { + if pLevel := level + 1; pLevel >= len(v.levels) { + return maxLevel + } else if v.levels[pLevel].overlaps(v.s.icmp, umin, umax, false) { + break + } + if gpLevel := level + 2; gpLevel < len(v.levels) { + overlaps = v.levels[gpLevel].getOverlaps(overlaps, v.s.icmp, umin, umax, false) + if overlaps.size() > int64(v.s.o.GetCompactionGPOverlaps(level)) { + break + } + } + } + } + } + return +} + +func (v *version) computeCompaction() { + // Precomputed best level for next compaction + bestLevel := int(-1) + bestScore := float64(-1) + + statFiles := make([]int, len(v.levels)) + statSizes := make([]string, len(v.levels)) + statScore := make([]string, len(v.levels)) + statTotSize := int64(0) + + for level, tables := range v.levels { + var score float64 + size := tables.size() + if level == 0 { + // We treat level-0 specially by bounding the number of files + // instead of number of bytes for two reasons: + // + // (1) With larger write-buffer sizes, it is nice not to do too + // many level-0 compaction. + // + // (2) The files in level-0 are merged on every read and + // therefore we wish to avoid too many files when the individual + // file size is small (perhaps because of a small write-buffer + // setting, or very high compression ratios, or lots of + // overwrites/deletions). + score = float64(len(tables)) / float64(v.s.o.GetCompactionL0Trigger()) + } else { + score = float64(size) / float64(v.s.o.GetCompactionTotalSize(level)) + } + + if score > bestScore { + bestLevel = level + bestScore = score + } + + statFiles[level] = len(tables) + statSizes[level] = shortenb(int(size)) + statScore[level] = fmt.Sprintf("%.2f", score) + statTotSize += size + } + + v.cLevel = bestLevel + v.cScore = bestScore + + v.s.logf("version@stat F·%v S·%s%v Sc·%v", statFiles, shortenb(int(statTotSize)), statSizes, statScore) +} + +func (v *version) needCompaction() bool { + return v.cScore >= 1 || atomic.LoadPointer(&v.cSeek) != nil +} + +type tablesScratch struct { + added map[int64]atRecord + deleted map[int64]struct{} +} + +type versionStaging struct { + base *version + levels []tablesScratch +} + +func (p *versionStaging) getScratch(level int) *tablesScratch { + if level >= len(p.levels) { + newLevels := make([]tablesScratch, level+1) + copy(newLevels, p.levels) + p.levels = newLevels + } + return &(p.levels[level]) +} + +func (p *versionStaging) commit(r *sessionRecord) { + // Deleted tables. + for _, r := range r.deletedTables { + scratch := p.getScratch(r.level) + if r.level < len(p.base.levels) && len(p.base.levels[r.level]) > 0 { + if scratch.deleted == nil { + scratch.deleted = make(map[int64]struct{}) + } + scratch.deleted[r.num] = struct{}{} + } + if scratch.added != nil { + delete(scratch.added, r.num) + } + } + + // New tables. + for _, r := range r.addedTables { + scratch := p.getScratch(r.level) + if scratch.added == nil { + scratch.added = make(map[int64]atRecord) + } + scratch.added[r.num] = r + if scratch.deleted != nil { + delete(scratch.deleted, r.num) + } + } +} + +func (p *versionStaging) finish(trivial bool) *version { + // Build new version. 
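computeCompaction above reduces the pick-next-level decision to one score per level: file count over trigger for level 0, total bytes over the level's size budget for the rest, with the maximum winning. A worked example under the usual goleveldb defaults (a level-0 trigger of 4 files and a 10 MiB budget at level 1; treat those numbers as assumptions of this sketch):

package main

import "fmt"

func main() {
	// Level 0 is scored by file count: many small files hurt reads.
	l0Score := float64(6) / 4.0 // 6 files against a trigger of 4 -> 1.50

	// Higher levels are scored by total bytes against the level budget.
	l1Score := float64(8<<20) / float64(10<<20) // 8MiB of 10MiB -> 0.80

	// needCompaction() fires once the best score reaches 1.0.
	fmt.Printf("L0=%.2f L1=%.2f -> compact level 0\n", l0Score, l1Score)
}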
+ nv := newVersion(p.base.s) + numLevel := len(p.levels) + if len(p.base.levels) > numLevel { + numLevel = len(p.base.levels) + } + nv.levels = make([]tFiles, numLevel) + for level := 0; level < numLevel; level++ { + var baseTabels tFiles + if level < len(p.base.levels) { + baseTabels = p.base.levels[level] + } + + if level < len(p.levels) { + scratch := p.levels[level] + + // Short circuit if there is no change at all. + if len(scratch.added) == 0 && len(scratch.deleted) == 0 { + nv.levels[level] = baseTabels + continue + } + + var nt tFiles + // Prealloc list if possible. + if n := len(baseTabels) + len(scratch.added) - len(scratch.deleted); n > 0 { + nt = make(tFiles, 0, n) + } + + // Base tables. + for _, t := range baseTabels { + if _, ok := scratch.deleted[t.fd.Num]; ok { + continue + } + if _, ok := scratch.added[t.fd.Num]; ok { + continue + } + nt = append(nt, t) + } + + // Avoid resort if only files in this level are deleted + if len(scratch.added) == 0 { + nv.levels[level] = nt + continue + } + + // For normal table compaction, one compaction will only involve two levels + // of files. And the new files generated after merging the source level and + // source+1 level related files can be inserted as a whole into source+1 level + // without any overlap with the other source+1 files. + // + // When the amount of data maintained by leveldb is large, the number of files + // per level will be very large. While qsort is very inefficient for sorting + // already ordered arrays. Therefore, for the normal table compaction, we use + // binary search here to find the insert index to insert a batch of new added + // files directly instead of using qsort. + if trivial && len(scratch.added) > 0 { + added := make(tFiles, 0, len(scratch.added)) + for _, r := range scratch.added { + added = append(added, tableFileFromRecord(r)) + } + if level == 0 { + added.sortByNum() + index := nt.searchNumLess(added[len(added)-1].fd.Num) + nt = append(nt[:index], append(added, nt[index:]...)...) + } else { + added.sortByKey(p.base.s.icmp) + _, amax := added.getRange(p.base.s.icmp) + index := nt.searchMin(p.base.s.icmp, amax) + nt = append(nt[:index], append(added, nt[index:]...)...) + } + nv.levels[level] = nt + continue + } + + // New tables. + for _, r := range scratch.added { + nt = append(nt, tableFileFromRecord(r)) + } + + if len(nt) != 0 { + // Sort tables. + if level == 0 { + nt.sortByNum() + } else { + nt.sortByKey(p.base.s.icmp) + } + + nv.levels[level] = nt + } + } else { + nv.levels[level] = baseTabels + } + } + + // Trim levels. + n := len(nv.levels) + for ; n > 0 && nv.levels[n-1] == nil; n-- { + } + nv.levels = nv.levels[:n] + + // Compute compaction score for new version. 
+ nv.computeCompaction() + + return nv +} + +type versionReleaser struct { + v *version + once bool +} + +func (vr *versionReleaser) Release() { + v := vr.v + v.s.vmu.Lock() + if !vr.once { + v.releaseNB() + vr.once = true + } + v.s.vmu.Unlock() +} diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/version_test.go b/vendor/github.com/syndtr/goleveldb/leveldb/version_test.go new file mode 100644 index 0000000..43e5b1f --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/leveldb/version_test.go @@ -0,0 +1,431 @@ +package leveldb + +import ( + "encoding/binary" + "math/rand" + "reflect" + "sync" + "testing" + "time" + + "github.com/onsi/gomega" + "github.com/syndtr/goleveldb/leveldb/storage" + "github.com/syndtr/goleveldb/leveldb/testutil" +) + +type testFileRec struct { + level int + num int64 +} + +func TestVersionStaging(t *testing.T) { + gomega.RegisterTestingT(t) + stor := testutil.NewStorage() + defer stor.Close() + s, err := newSession(stor, nil) + if err != nil { + t.Fatal(err) + } + defer func() { + s.close() + s.release() + }() + + v := newVersion(s) + v.newStaging() + + tmp := make([]byte, 4) + mik := func(i uint64) []byte { + binary.BigEndian.PutUint32(tmp, uint32(i)) + return []byte(makeInternalKey(nil, tmp, 0, keyTypeVal)) + } + + for i, x := range []struct { + add, del []testFileRec + trivial bool + levels [][]int64 + }{ + { + add: []testFileRec{ + {1, 1}, + }, + levels: [][]int64{ + {}, + {1}, + }, + }, + { + add: []testFileRec{ + {1, 1}, + }, + levels: [][]int64{ + {}, + {1}, + }, + }, + { + del: []testFileRec{ + {1, 1}, + }, + levels: [][]int64{}, + }, + { + add: []testFileRec{ + {0, 1}, + {0, 3}, + {0, 2}, + {2, 5}, + {1, 4}, + }, + levels: [][]int64{ + {3, 2, 1}, + {4}, + {5}, + }, + }, + { + add: []testFileRec{ + {1, 6}, + {2, 5}, + }, + del: []testFileRec{ + {0, 1}, + {0, 4}, + }, + levels: [][]int64{ + {3, 2}, + {4, 6}, + {5}, + }, + }, + { + del: []testFileRec{ + {0, 3}, + {0, 2}, + {1, 4}, + {1, 6}, + {2, 5}, + }, + levels: [][]int64{}, + }, + { + add: []testFileRec{ + {0, 1}, + }, + levels: [][]int64{ + {1}, + }, + }, + { + add: []testFileRec{ + {1, 2}, + }, + levels: [][]int64{ + {1}, + {2}, + }, + }, + { + add: []testFileRec{ + {0, 3}, + }, + levels: [][]int64{ + {3, 1}, + {2}, + }, + }, + { + add: []testFileRec{ + {6, 9}, + }, + levels: [][]int64{ + {3, 1}, + {2}, + {}, + {}, + {}, + {}, + {9}, + }, + }, + { + del: []testFileRec{ + {6, 9}, + }, + levels: [][]int64{ + {3, 1}, + {2}, + }, + }, + // memory compaction + { + add: []testFileRec{ + {0, 5}, + }, + trivial: true, + levels: [][]int64{ + {5, 3, 1}, + {2}, + }, + }, + // memory compaction + { + add: []testFileRec{ + {0, 4}, + }, + trivial: true, + levels: [][]int64{ + {5, 4, 3, 1}, + {2}, + }, + }, + // table compaction + { + add: []testFileRec{ + {1, 6}, + {1, 7}, + {1, 8}, + }, + del: []testFileRec{ + {0, 3}, + {0, 4}, + {0, 5}, + }, + trivial: true, + levels: [][]int64{ + {1}, + {2, 6, 7, 8}, + }, + }, + } { + rec := &sessionRecord{} + for _, f := range x.add { + ik := mik(uint64(f.num)) + rec.addTable(f.level, f.num, 1, ik, ik) + } + for _, f := range x.del { + rec.delTable(f.level, f.num) + } + vs := v.newStaging() + vs.commit(rec) + v = vs.finish(x.trivial) + if len(v.levels) != len(x.levels) { + t.Fatalf("#%d: invalid level count: want=%d got=%d", i, len(x.levels), len(v.levels)) + } + for j, want := range x.levels { + tables := v.levels[j] + if len(want) != len(tables) { + t.Fatalf("#%d.%d: invalid tables count: want=%d got=%d", i, j, len(want), len(tables)) + } + got := make([]int64, 
len(tables))
+ for k, t := range tables {
+ got[k] = t.fd.Num
+ }
+ if !reflect.DeepEqual(want, got) {
+ t.Fatalf("#%d.%d: invalid tables: want=%v got=%v", i, j, want, got)
+ }
+ }
+ }
+}
+
+func TestVersionReference(t *testing.T) {
+ gomega.RegisterTestingT(t)
+ stor := testutil.NewStorage()
+ defer stor.Close()
+ s, err := newSession(stor, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ s.close()
+ s.release()
+ }()
+
+ tmp := make([]byte, 4)
+ mik := func(i uint64) []byte {
+ binary.BigEndian.PutUint32(tmp, uint32(i))
+ return []byte(makeInternalKey(nil, tmp, 0, keyTypeVal))
+ }
+
+ // Test file reference correctness for normal version tasks.
+ refc := make(chan map[int64]int)
+
+ for i, x := range []struct {
+ add, del []testFileRec
+ expect map[int64]int
+ failed bool
+ }{
+ {
+ []testFileRec{{0, 1}, {0, 2}},
+ nil,
+ map[int64]int{1: 1, 2: 1},
+ false,
+ },
+ {
+ []testFileRec{{0, 3}, {0, 4}},
+ []testFileRec{{0, 1}},
+ map[int64]int{2: 1, 3: 1, 4: 1},
+ false,
+ },
+ {
+ []testFileRec{{0, 1}, {0, 5}, {0, 6}, {0, 7}},
+ []testFileRec{{0, 2}, {0, 3}, {0, 4}},
+ map[int64]int{1: 1, 5: 1, 6: 1, 7: 1},
+ false,
+ },
+ {
+ nil,
+ nil,
+ map[int64]int{1: 1, 5: 1, 6: 1, 7: 1},
+ true,
+ },
+ {
+ []testFileRec{{0, 1}, {0, 5}, {0, 6}, {0, 7}},
+ nil,
+ map[int64]int{1: 2, 5: 2, 6: 2, 7: 2},
+ false,
+ },
+ {
+ nil,
+ []testFileRec{{0, 1}, {0, 5}, {0, 6}, {0, 7}},
+ map[int64]int{1: 1, 5: 1, 6: 1, 7: 1},
+ false,
+ },
+ {
+ []testFileRec{{0, 0}},
+ []testFileRec{{0, 1}, {0, 5}, {0, 6}, {0, 7}},
+ map[int64]int{0: 1},
+ false,
+ },
+ } {
+ rec := &sessionRecord{}
+ for n, f := range x.add {
+ rec.addTable(f.level, f.num, 1, mik(uint64(i+n)), mik(uint64(i+n)))
+ }
+ for _, f := range x.del {
+ rec.delTable(f.level, f.num)
+ }
+
+ // Simulate some read operations.
+ var wg sync.WaitGroup
+ readN := rand.Intn(300)
+ for i := 0; i < readN; i++ {
+ wg.Add(1)
+ go func() {
+ v := s.version()
+ time.Sleep(time.Millisecond * time.Duration(rand.Intn(300)))
+ v.release()
+ wg.Done()
+ }()
+ }
+
+ v := s.version()
+ vs := v.newStaging()
+ vs.commit(rec)
+ nv := vs.finish(false)
+
+ if x.failed {
+ s.abandon <- nv.id
+ } else {
+ s.setVersion(rec, nv)
+ }
+ v.release()
+
+ // Wait for all read operations to finish.
+ wg.Wait()
+
+ time.Sleep(100 * time.Millisecond) // Wait for lazy reference tasks to finish.
+
+ s.fileRefCh <- refc
+ ref := <-refc
+ if !reflect.DeepEqual(ref, x.expect) {
+ t.Errorf("case %d failed, file reference mismatch, GOT %v, WANT %v", i, ref, x.expect)
+ }
+ }
+
+ // Test version task overflow.
+ var longV = s.version() // This version is held by some long-running operation.
+ var exp = map[int64]int{0: 1, maxCachedNumber: 1}
+ for i := 1; i <= maxCachedNumber; i++ {
+ rec := &sessionRecord{}
+ rec.addTable(0, int64(i), 1, mik(uint64(i)), mik(uint64(i)))
+ rec.delTable(0, int64(i-1))
+ v := s.version()
+ vs := v.newStaging()
+ vs.commit(rec)
+ nv := vs.finish(false)
+ s.setVersion(rec, nv)
+ v.release()
+ }
+ time.Sleep(100 * time.Millisecond) // Wait for lazy reference tasks to finish.
+
+ s.fileRefCh <- refc
+ ref := <-refc
+ if !reflect.DeepEqual(exp, ref) {
+ t.Errorf("file reference mismatch, GOT %v, WANT %v", ref, exp)
+ }
+
+ longV.release()
+ s.fileRefCh <- refc
+ ref = <-refc
+ delete(exp, 0)
+ if !reflect.DeepEqual(exp, ref) {
+ t.Errorf("file reference mismatch, GOT %v, WANT %v", ref, exp)
+ }
+}
+
+func BenchmarkVersionStagingNonTrivial(b *testing.B) {
+ benchmarkVersionStaging(b, false, 100000)
+}
+
+func BenchmarkVersionStagingTrivial(b *testing.B) {
+ benchmarkVersionStaging(b, true, 100000)
+}
+
+func 
benchmarkVersionStaging(b *testing.B, trivial bool, size int) {
+ stor := storage.NewMemStorage()
+ defer stor.Close()
+ s, err := newSession(stor, nil)
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer func() {
+ s.close()
+ s.release()
+ }()
+
+ tmp := make([]byte, 4)
+ mik := func(i uint64) []byte {
+ binary.BigEndian.PutUint32(tmp, uint32(i))
+ return []byte(makeInternalKey(nil, tmp, 0, keyTypeVal))
+ }
+
+ rec := &sessionRecord{}
+ for i := 0; i < size; i++ {
+ ik := mik(uint64(i))
+ rec.addTable(1, int64(i), 1, ik, ik)
+ }
+
+ v := newVersion(s)
+ vs := v.newStaging()
+ vs.commit(rec)
+ v = vs.finish(false)
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ rec := &sessionRecord{}
+ index := rand.Intn(size)
+ ik := mik(uint64(index))
+
+ cnt := 0
+ for j := index; j < size && cnt <= 3; j++ {
+ rec.addTable(1, int64(j), 1, ik, ik)
+ cnt++
+ }
+ vs := v.newStaging()
+ vs.commit(rec)
+ vs.finish(trivial)
+ }
+}
diff --git a/vendor/github.com/syndtr/goleveldb/manualtest/dbstress/key.go b/vendor/github.com/syndtr/goleveldb/manualtest/dbstress/key.go
new file mode 100644
index 0000000..c9f6963
--- /dev/null
+++ b/vendor/github.com/syndtr/goleveldb/manualtest/dbstress/key.go
@@ -0,0 +1,137 @@
+package main
+
+import (
+ "encoding/binary"
+ "fmt"
+
+ "github.com/syndtr/goleveldb/leveldb/errors"
+ "github.com/syndtr/goleveldb/leveldb/storage"
+)
+
+type ErrIkeyCorrupted struct {
+ Ikey []byte
+ Reason string
+}
+
+func (e *ErrIkeyCorrupted) Error() string {
+ return fmt.Sprintf("leveldb: iKey %q corrupted: %s", e.Ikey, e.Reason)
+}
+
+func newErrIkeyCorrupted(ikey []byte, reason string) error {
+ return errors.NewErrCorrupted(storage.FileDesc{}, &ErrIkeyCorrupted{append([]byte{}, ikey...), reason})
+}
+
+type kType int
+
+func (kt kType) String() string {
+ switch kt {
+ case ktDel:
+ return "d"
+ case ktVal:
+ return "v"
+ }
+ return "x"
+}
+
+// Value types encoded as the last component of internal keys.
+// Don't modify; these values are saved to disk.
+const (
+ ktDel kType = iota
+ ktVal
+)
+
+// ktSeek defines the kType that should be passed when constructing an
+// internal key for seeking to a particular sequence number (since we
+// sort sequence numbers in decreasing order and the value type is
+// embedded as the low 8 bits in the sequence number in internal keys,
+// we need to use the highest-numbered kType, not the lowest).
+const ktSeek = ktVal
+
+const (
+ // Maximum value possible for a sequence number; the low 8 bits are
+ // used by the value type, so both can be packed together into a
+ // single 64-bit integer.
+ kMaxSeq uint64 = (uint64(1) << 56) - 1
+ // Maximum value possible for packed sequence number and type.
+ kMaxNum uint64 = (kMaxSeq << 8) | uint64(ktSeek)
+)
+
+// Maximum number encoded in bytes.
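+//
+// For reference, newIkey below builds an internal key as the user key
+// followed by the 8-byte little-endian encoding of (seq<<8 | kt).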
+var kMaxNumBytes = make([]byte, 8) + +func init() { + binary.LittleEndian.PutUint64(kMaxNumBytes, kMaxNum) +} + +type iKey []byte + +func newIkey(ukey []byte, seq uint64, kt kType) iKey { + if seq > kMaxSeq { + panic("leveldb: invalid sequence number") + } else if kt > ktVal { + panic("leveldb: invalid type") + } + + ik := make(iKey, len(ukey)+8) + copy(ik, ukey) + binary.LittleEndian.PutUint64(ik[len(ukey):], (seq<<8)|uint64(kt)) + return ik +} + +func parseIkey(ik []byte) (ukey []byte, seq uint64, kt kType, err error) { + if len(ik) < 8 { + return nil, 0, 0, newErrIkeyCorrupted(ik, "invalid length") + } + num := binary.LittleEndian.Uint64(ik[len(ik)-8:]) + seq, kt = uint64(num>>8), kType(num&0xff) + if kt > ktVal { + return nil, 0, 0, newErrIkeyCorrupted(ik, "invalid type") + } + ukey = ik[:len(ik)-8] + return +} + +func validIkey(ik []byte) bool { + _, _, _, err := parseIkey(ik) + return err == nil +} + +func (ik iKey) assert() { + if ik == nil { + panic("leveldb: nil iKey") + } + if len(ik) < 8 { + panic(fmt.Sprintf("leveldb: iKey %q, len=%d: invalid length", ik, len(ik))) + } +} + +func (ik iKey) ukey() []byte { + ik.assert() + return ik[:len(ik)-8] +} + +func (ik iKey) num() uint64 { + ik.assert() + return binary.LittleEndian.Uint64(ik[len(ik)-8:]) +} + +func (ik iKey) parseNum() (seq uint64, kt kType) { + num := ik.num() + seq, kt = uint64(num>>8), kType(num&0xff) + if kt > ktVal { + panic(fmt.Sprintf("leveldb: iKey %q, len=%d: invalid type %#x", ik, len(ik), kt)) + } + return +} + +func (ik iKey) String() string { + if ik == nil { + return "" + } + + if ukey, seq, kt, err := parseIkey(ik); err == nil { + return fmt.Sprintf("%x,%s%d", ukey, kt, seq) + } else { + return "" + } +} diff --git a/vendor/github.com/syndtr/goleveldb/manualtest/dbstress/main.go b/vendor/github.com/syndtr/goleveldb/manualtest/dbstress/main.go new file mode 100644 index 0000000..d4bd6d6 --- /dev/null +++ b/vendor/github.com/syndtr/goleveldb/manualtest/dbstress/main.go @@ -0,0 +1,630 @@ +package main + +import ( + "crypto/rand" + "encoding/binary" + "flag" + "fmt" + "log" + mrand "math/rand" + "net/http" + _ "net/http/pprof" + "os" + "os/signal" + "path" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/syndtr/goleveldb/leveldb" + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" + "github.com/syndtr/goleveldb/leveldb/table" + "github.com/syndtr/goleveldb/leveldb/util" +) + +var ( + dbPath = path.Join(os.TempDir(), "goleveldb-testdb") + openFilesCacheCapacity = 500 + keyLen = 63 + valueLen = 256 + numKeys = arrayInt{100000, 1332, 531, 1234, 9553, 1024, 35743} + httpProf = "127.0.0.1:5454" + transactionProb = 0.5 + enableBlockCache = false + enableCompression = false + enableBufferPool = false + + wg = new(sync.WaitGroup) + done, fail uint32 + + bpool *util.BufferPool +) + +type arrayInt []int + +func (a arrayInt) String() string { + var str string + for i, n := range a { + if i > 0 { + str += "," + } + str += strconv.Itoa(n) + } + return str +} + +func (a *arrayInt) Set(str string) error { + var na arrayInt + for _, s := range strings.Split(str, ",") { + s = strings.TrimSpace(s) + if s != "" { + n, err := strconv.Atoi(s) + if err != nil { + return err + } + na = append(na, n) + } + } + *a = na + return nil +} + +func init() { + flag.StringVar(&dbPath, "db", dbPath, "testdb path") + flag.IntVar(&openFilesCacheCapacity, "openfilescachecap", openFilesCacheCapacity, "open files cache 
capacity") + flag.IntVar(&keyLen, "keylen", keyLen, "key length") + flag.IntVar(&valueLen, "valuelen", valueLen, "value length") + flag.Var(&numKeys, "numkeys", "num keys") + flag.StringVar(&httpProf, "httpprof", httpProf, "http pprof listen addr") + flag.Float64Var(&transactionProb, "transactionprob", transactionProb, "probablity of writes using transaction") + flag.BoolVar(&enableBufferPool, "enablebufferpool", enableBufferPool, "enable buffer pool") + flag.BoolVar(&enableBlockCache, "enableblockcache", enableBlockCache, "enable block cache") + flag.BoolVar(&enableCompression, "enablecompression", enableCompression, "enable block compression") +} + +func randomData(dst []byte, ns, prefix byte, i uint32, dataLen int) []byte { + if dataLen < (2+4+4)*2+4 { + panic("dataLen is too small") + } + if cap(dst) < dataLen { + dst = make([]byte, dataLen) + } else { + dst = dst[:dataLen] + } + half := (dataLen - 4) / 2 + if _, err := rand.Reader.Read(dst[2 : half-8]); err != nil { + panic(err) + } + dst[0] = ns + dst[1] = prefix + binary.LittleEndian.PutUint32(dst[half-8:], i) + binary.LittleEndian.PutUint32(dst[half-8:], i) + binary.LittleEndian.PutUint32(dst[half-4:], util.NewCRC(dst[:half-4]).Value()) + full := half * 2 + copy(dst[half:full], dst[:half]) + if full < dataLen-4 { + if _, err := rand.Reader.Read(dst[full : dataLen-4]); err != nil { + panic(err) + } + } + binary.LittleEndian.PutUint32(dst[dataLen-4:], util.NewCRC(dst[:dataLen-4]).Value()) + return dst +} + +func dataSplit(data []byte) (data0, data1 []byte) { + n := (len(data) - 4) / 2 + return data[:n], data[n : n+n] +} + +func dataNS(data []byte) byte { + return data[0] +} + +func dataPrefix(data []byte) byte { + return data[1] +} + +func dataI(data []byte) uint32 { + return binary.LittleEndian.Uint32(data[(len(data)-4)/2-8:]) +} + +func dataChecksum(data []byte) (uint32, uint32) { + checksum0 := binary.LittleEndian.Uint32(data[len(data)-4:]) + checksum1 := util.NewCRC(data[:len(data)-4]).Value() + return checksum0, checksum1 +} + +func dataPrefixSlice(ns, prefix byte) *util.Range { + return util.BytesPrefix([]byte{ns, prefix}) +} + +func dataNsSlice(ns byte) *util.Range { + return util.BytesPrefix([]byte{ns}) +} + +type testingStorage struct { + storage.Storage +} + +func (ts *testingStorage) scanTable(fd storage.FileDesc, checksum bool) (corrupted bool) { + r, err := ts.Open(fd) + if err != nil { + log.Fatal(err) + } + defer r.Close() + + size, err := r.Seek(0, os.SEEK_END) + if err != nil { + log.Fatal(err) + } + + o := &opt.Options{ + DisableLargeBatchTransaction: true, + Strict: opt.NoStrict, + } + if checksum { + o.Strict = opt.StrictBlockChecksum | opt.StrictReader + } + tr, err := table.NewReader(r, size, fd, nil, bpool, o) + if err != nil { + log.Fatal(err) + } + defer tr.Release() + + checkData := func(i int, t string, data []byte) bool { + if len(data) == 0 { + panic(fmt.Sprintf("[%v] nil data: i=%d t=%s", fd, i, t)) + } + + checksum0, checksum1 := dataChecksum(data) + if checksum0 != checksum1 { + atomic.StoreUint32(&fail, 1) + atomic.StoreUint32(&done, 1) + corrupted = true + + data0, data1 := dataSplit(data) + data0c0, data0c1 := dataChecksum(data0) + data1c0, data1c1 := dataChecksum(data1) + log.Printf("FATAL: [%v] Corrupted data i=%d t=%s (%#x != %#x): %x(%v) vs %x(%v)", + fd, i, t, checksum0, checksum1, data0, data0c0 == data0c1, data1, data1c0 == data1c1) + return true + } + return false + } + + iter := tr.NewIterator(nil, nil) + defer iter.Release() + for i := 0; iter.Next(); i++ { + ukey, _, kt, kerr := 
parseIkey(iter.Key()) + if kerr != nil { + atomic.StoreUint32(&fail, 1) + atomic.StoreUint32(&done, 1) + corrupted = true + + log.Printf("FATAL: [%v] Corrupted ikey i=%d: %v", fd, i, kerr) + return + } + if checkData(i, "key", ukey) { + return + } + if kt == ktVal && checkData(i, "value", iter.Value()) { + return + } + } + if err := iter.Error(); err != nil { + if errors.IsCorrupted(err) { + atomic.StoreUint32(&fail, 1) + atomic.StoreUint32(&done, 1) + corrupted = true + + log.Printf("FATAL: [%v] Corruption detected: %v", fd, err) + } else { + log.Fatal(err) + } + } + + return +} + +func (ts *testingStorage) Remove(fd storage.FileDesc) error { + if atomic.LoadUint32(&fail) == 1 { + return nil + } + + if fd.Type == storage.TypeTable { + if ts.scanTable(fd, true) { + return nil + } + } + return ts.Storage.Remove(fd) +} + +type latencyStats struct { + mark time.Time + dur, min, max time.Duration + num int +} + +func (s *latencyStats) start() { + s.mark = time.Now() +} + +func (s *latencyStats) record(n int) { + if s.mark.IsZero() { + panic("not started") + } + dur := time.Now().Sub(s.mark) + dur1 := dur / time.Duration(n) + if dur1 < s.min || s.min == 0 { + s.min = dur1 + } + if dur1 > s.max { + s.max = dur1 + } + s.dur += dur + s.num += n + s.mark = time.Time{} +} + +func (s *latencyStats) ratePerSec() int { + durSec := s.dur / time.Second + if durSec > 0 { + return s.num / int(durSec) + } + return s.num +} + +func (s *latencyStats) avg() time.Duration { + if s.num > 0 { + return s.dur / time.Duration(s.num) + } + return 0 +} + +func (s *latencyStats) add(x *latencyStats) { + if x.min < s.min || s.min == 0 { + s.min = x.min + } + if x.max > s.max { + s.max = x.max + } + s.dur += x.dur + s.num += x.num +} + +func main() { + flag.Parse() + + if enableBufferPool { + bpool = util.NewBufferPool(opt.DefaultBlockSize + 128) + } + + log.Printf("Test DB stored at %q", dbPath) + if httpProf != "" { + log.Printf("HTTP pprof listening at %q", httpProf) + runtime.SetBlockProfileRate(1) + go func() { + if err := http.ListenAndServe(httpProf, nil); err != nil { + log.Fatalf("HTTPPROF: %v", err) + } + }() + } + + runtime.GOMAXPROCS(runtime.NumCPU()) + + os.RemoveAll(dbPath) + stor, err := storage.OpenFile(dbPath, false) + if err != nil { + log.Fatal(err) + } + tstor := &testingStorage{stor} + defer tstor.Close() + + fatalf := func(err error, format string, v ...interface{}) { + atomic.StoreUint32(&fail, 1) + atomic.StoreUint32(&done, 1) + log.Printf("FATAL: "+format, v...) 
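+ // When the failure is a table corruption, rescan the offending table
+ // so the corrupted key/value pair is located and logged before the
+ // goroutine exits.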
+ if err != nil && errors.IsCorrupted(err) {
+ cerr := err.(*errors.ErrCorrupted)
+ if !cerr.Fd.Zero() && cerr.Fd.Type == storage.TypeTable {
+ log.Print("FATAL: corruption detected, scanning...")
+ if !tstor.scanTable(storage.FileDesc{Type: storage.TypeTable, Num: cerr.Fd.Num}, false) {
+ log.Printf("FATAL: unable to find corrupted key/value pair in table %v", cerr.Fd)
+ }
+ }
+ }
+ runtime.Goexit()
+ }
+
+ if openFilesCacheCapacity == 0 {
+ openFilesCacheCapacity = -1
+ }
+ o := &opt.Options{
+ OpenFilesCacheCapacity: openFilesCacheCapacity,
+ DisableBufferPool: !enableBufferPool,
+ DisableBlockCache: !enableBlockCache,
+ ErrorIfExist: true,
+ Compression: opt.NoCompression,
+ }
+ if enableCompression {
+ o.Compression = opt.DefaultCompression
+ }
+
+ db, err := leveldb.Open(tstor, o)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer db.Close()
+
+ var (
+ mu = &sync.Mutex{}
+ gGetStat = &latencyStats{}
+ gIterStat = &latencyStats{}
+ gWriteStat = &latencyStats{}
+ gTransactionStat = &latencyStats{}
+ startTime = time.Now()
+
+ writeReq = make(chan *leveldb.Batch)
+ writeAck = make(chan error)
+ writeAckAck = make(chan struct{})
+ )
+
+ go func() {
+ for b := range writeReq {
+
+ var err error
+ if mrand.Float64() < transactionProb {
+ log.Print("> Write using transaction")
+ gTransactionStat.start()
+ var tr *leveldb.Transaction
+ if tr, err = db.OpenTransaction(); err == nil {
+ if err = tr.Write(b, nil); err == nil {
+ if err = tr.Commit(); err == nil {
+ gTransactionStat.record(b.Len())
+ }
+ } else {
+ tr.Discard()
+ }
+ }
+ } else {
+ gWriteStat.start()
+ if err = db.Write(b, nil); err == nil {
+ gWriteStat.record(b.Len())
+ }
+ }
+ writeAck <- err
+ <-writeAckAck
+ }
+ }()
+
+ go func() {
+ for {
+ time.Sleep(3 * time.Second)
+
+ log.Print("------------------------")
+
+ log.Printf("> Elapsed=%v", time.Now().Sub(startTime))
+ mu.Lock()
+ log.Printf("> GetLatencyMin=%v GetLatencyMax=%v GetLatencyAvg=%v GetRatePerSec=%d",
+ gGetStat.min, gGetStat.max, gGetStat.avg(), gGetStat.ratePerSec())
+ log.Printf("> IterLatencyMin=%v IterLatencyMax=%v IterLatencyAvg=%v IterRatePerSec=%d",
+ gIterStat.min, gIterStat.max, gIterStat.avg(), gIterStat.ratePerSec())
+ log.Printf("> WriteLatencyMin=%v WriteLatencyMax=%v WriteLatencyAvg=%v WriteRatePerSec=%d",
+ gWriteStat.min, gWriteStat.max, gWriteStat.avg(), gWriteStat.ratePerSec())
+ log.Printf("> TransactionLatencyMin=%v TransactionLatencyMax=%v TransactionLatencyAvg=%v TransactionRatePerSec=%d",
+ gTransactionStat.min, gTransactionStat.max, gTransactionStat.avg(), gTransactionStat.ratePerSec())
+ mu.Unlock()
+
+ cachedblock, _ := db.GetProperty("leveldb.cachedblock")
+ openedtables, _ := db.GetProperty("leveldb.openedtables")
+ alivesnaps, _ := db.GetProperty("leveldb.alivesnaps")
+ aliveiters, _ := db.GetProperty("leveldb.aliveiters")
+ blockpool, _ := db.GetProperty("leveldb.blockpool")
+ writeDelay, _ := db.GetProperty("leveldb.writedelay")
+ ioStats, _ := db.GetProperty("leveldb.iostats")
+ compCount, _ := db.GetProperty("leveldb.compcount")
+ log.Printf("> BlockCache=%s OpenedTables=%s AliveSnaps=%s AliveIter=%s BlockPool=%q WriteDelay=%q IOStats=%q CompCount=%q",
+ cachedblock, openedtables, alivesnaps, aliveiters, blockpool, writeDelay, ioStats, compCount)
+ log.Print("------------------------")
+ }
+ }()
+
+ for ns, numKey := range numKeys {
+ func(ns, numKey int) {
+ log.Printf("[%02d] STARTING: numKey=%d", ns, numKey)
+
+ keys := make([][]byte, numKey)
+ for i := range keys {
+ keys[i] = randomData(nil, byte(ns), 1, uint32(i), keyLen)
+ }
+
+ // Writer: each pass wi batches a fresh (k2, v2) pair plus a k1 -> k2
+ // pointer for every base key, hands the batch to the write goroutine,
+ // then snapshots the result for the readers spawned below.
+ 
wg.Add(1) + go func() { + var wi uint32 + defer func() { + log.Printf("[%02d] WRITER DONE #%d", ns, wi) + wg.Done() + }() + + var ( + b = new(leveldb.Batch) + k2, v2 []byte + nReader int32 + ) + for atomic.LoadUint32(&done) == 0 { + log.Printf("[%02d] WRITER #%d", ns, wi) + + b.Reset() + for _, k1 := range keys { + k2 = randomData(k2, byte(ns), 2, wi, keyLen) + v2 = randomData(v2, byte(ns), 3, wi, valueLen) + b.Put(k2, v2) + b.Put(k1, k2) + } + writeReq <- b + if err := <-writeAck; err != nil { + writeAckAck <- struct{}{} + fatalf(err, "[%02d] WRITER #%d db.Write: %v", ns, wi, err) + } + + snap, err := db.GetSnapshot() + if err != nil { + writeAckAck <- struct{}{} + fatalf(err, "[%02d] WRITER #%d db.GetSnapshot: %v", ns, wi, err) + } + + writeAckAck <- struct{}{} + + wg.Add(1) + atomic.AddInt32(&nReader, 1) + go func(snapwi uint32, snap *leveldb.Snapshot) { + var ( + ri int + iterStat = &latencyStats{} + getStat = &latencyStats{} + ) + defer func() { + mu.Lock() + gGetStat.add(getStat) + gIterStat.add(iterStat) + mu.Unlock() + + atomic.AddInt32(&nReader, -1) + log.Printf("[%02d] READER #%d.%d DONE Snap=%v Alive=%d IterLatency=%v GetLatency=%v", ns, snapwi, ri, snap, atomic.LoadInt32(&nReader), iterStat.avg(), getStat.avg()) + snap.Release() + wg.Done() + }() + + stopi := snapwi + 3 + for (ri < 3 || atomic.LoadUint32(&wi) < stopi) && atomic.LoadUint32(&done) == 0 { + var n int + iter := snap.NewIterator(dataPrefixSlice(byte(ns), 1), nil) + iterStat.start() + for iter.Next() { + k1 := iter.Key() + k2 := iter.Value() + iterStat.record(1) + + if dataNS(k2) != byte(ns) { + fatalf(nil, "[%02d] READER #%d.%d K%d invalid in-key NS: want=%d got=%d", ns, snapwi, ri, n, ns, dataNS(k2)) + } + + kwritei := dataI(k2) + if kwritei != snapwi { + fatalf(nil, "[%02d] READER #%d.%d K%d invalid in-key iter num: %d", ns, snapwi, ri, n, kwritei) + } + + getStat.start() + v2, err := snap.Get(k2, nil) + if err != nil { + fatalf(err, "[%02d] READER #%d.%d K%d snap.Get: %v\nk1: %x\n -> k2: %x", ns, snapwi, ri, n, err, k1, k2) + } + getStat.record(1) + + if checksum0, checksum1 := dataChecksum(v2); checksum0 != checksum1 { + err := &errors.ErrCorrupted{Fd: storage.FileDesc{Type: 0xff, Num: 0}, Err: fmt.Errorf("v2: %x: checksum mismatch: %v vs %v", v2, checksum0, checksum1)} + fatalf(err, "[%02d] READER #%d.%d K%d snap.Get: %v\nk1: %x\n -> k2: %x", ns, snapwi, ri, n, err, k1, k2) + } + + n++ + iterStat.start() + } + iter.Release() + if err := iter.Error(); err != nil { + fatalf(err, "[%02d] READER #%d.%d K%d iter.Error: %v", ns, snapwi, ri, numKey, err) + } + if n != numKey { + fatalf(nil, "[%02d] READER #%d.%d missing keys: want=%d got=%d", ns, snapwi, ri, numKey, n) + } + + ri++ + } + }(wi, snap) + + atomic.AddUint32(&wi, 1) + } + }() + + delB := new(leveldb.Batch) + wg.Add(1) + go func() { + var ( + i int + iterStat = &latencyStats{} + ) + defer func() { + log.Printf("[%02d] SCANNER DONE #%d", ns, i) + wg.Done() + }() + + time.Sleep(2 * time.Second) + + for atomic.LoadUint32(&done) == 0 { + var n int + delB.Reset() + iter := db.NewIterator(dataNsSlice(byte(ns)), nil) + iterStat.start() + for iter.Next() && atomic.LoadUint32(&done) == 0 { + k := iter.Key() + v := iter.Value() + iterStat.record(1) + + for ci, x := range [...][]byte{k, v} { + checksum0, checksum1 := dataChecksum(x) + if checksum0 != checksum1 { + if ci == 0 { + fatalf(nil, "[%02d] SCANNER %d.%d invalid key checksum: want %d, got %d\n%x -> %x", ns, i, n, checksum0, checksum1, k, v) + } else { + fatalf(nil, "[%02d] SCANNER %d.%d invalid value checksum: 
want %d, got %d\n%x -> %x", ns, i, n, checksum0, checksum1, k, v)
+ }
+ }
+ }
+
+ if dataPrefix(k) == 2 || mrand.Int()%999 == 0 {
+ delB.Delete(k)
+ }
+
+ n++
+ iterStat.start()
+ }
+ iter.Release()
+ if err := iter.Error(); err != nil {
+ fatalf(err, "[%02d] SCANNER #%d.%d iter.Error: %v", ns, i, n, err)
+ }
+
+ if n > 0 {
+ log.Printf("[%02d] SCANNER #%d IterLatency=%v", ns, i, iterStat.avg())
+ }
+
+ if delB.Len() > 0 && atomic.LoadUint32(&done) == 0 {
+ t := time.Now()
+ writeReq <- delB
+ if err := <-writeAck; err != nil {
+ writeAckAck <- struct{}{}
+ fatalf(err, "[%02d] SCANNER #%d db.Write: %v", ns, i, err)
+ } else {
+ writeAckAck <- struct{}{}
+ }
+ log.Printf("[%02d] SCANNER #%d Deleted=%d Time=%v", ns, i, delB.Len(), time.Now().Sub(t))
+ }
+
+ i++
+ }
+ }()
+ }(ns, numKey)
+ }
+
+ go func() {
+ sig := make(chan os.Signal, 1) // Buffered, as signal.Notify requires.
+ signal.Notify(sig, os.Interrupt, os.Kill)
+ log.Printf("Got signal: %v, exiting...", <-sig)
+ atomic.StoreUint32(&done, 1)
+ }()
+
+ wg.Wait()
+}
diff --git a/vendor/github.com/syndtr/goleveldb/manualtest/filelock/main.go b/vendor/github.com/syndtr/goleveldb/manualtest/filelock/main.go
new file mode 100644
index 0000000..192951f
--- /dev/null
+++ b/vendor/github.com/syndtr/goleveldb/manualtest/filelock/main.go
@@ -0,0 +1,85 @@
+package main
+
+import (
+ "bufio"
+ "bytes"
+ "flag"
+ "fmt"
+ "os"
+ "os/exec"
+ "path/filepath"
+
+ "github.com/syndtr/goleveldb/leveldb/storage"
+)
+
+var (
+ filename string
+ child bool
+)
+
+func init() {
+ flag.StringVar(&filename, "filename", filepath.Join(os.TempDir(), "goleveldb_filelock_test"), "Filename used for testing")
+ flag.BoolVar(&child, "child", false, "This is the child")
+}
+
+func runChild() error {
+ var args []string
+ args = append(args, os.Args[1:]...)
+ args = append(args, "-child")
+ cmd := exec.Command(os.Args[0], args...)
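+ // Collect the child's stdout in a buffer so it can be echoed below,
+ // line by line, with a "[Child]" prefix.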
+ var out bytes.Buffer
+ cmd.Stdout = &out
+ err := cmd.Run()
+ r := bufio.NewReader(&out)
+ for {
+ line, _, e1 := r.ReadLine()
+ if e1 != nil {
+ break
+ }
+ fmt.Println("[Child]", string(line))
+ }
+ return err
+}
+
+func main() {
+ flag.Parse()
+
+ fmt.Printf("Using path: %s\n", filename)
+ if child {
+ fmt.Println("Child flag set.")
+ }
+
+ stor, err := storage.OpenFile(filename, false)
+ if err != nil {
+ fmt.Printf("Could not open storage: %s\n", err)
+ os.Exit(10)
+ }
+
+ if !child {
+ fmt.Println("Executing child -- first test (expecting error)")
+ err := runChild()
+ if err == nil {
+ fmt.Println("Expected an error from child, got none")
+ } else if err.Error() != "exit status 10" {
+ fmt.Println("Got unexpected error from child:", err)
+ } else {
+ fmt.Printf("Got error from child: %s (expected)\n", err)
+ }
+ }
+
+ err = stor.Close()
+ if err != nil {
+ fmt.Printf("Error when closing storage: %s\n", err)
+ os.Exit(11)
+ }
+
+ if !child {
+ fmt.Println("Executing child -- second test")
+ err := runChild()
+ if err != nil {
+ fmt.Println("Got unexpected error from child:", err)
+ }
+ }
+
+ os.RemoveAll(filename)
+}
diff --git a/xorm.go b/xorm.go
index c1f332f..169bbd5 100644
--- a/xorm.go
+++ b/xorm.go
@@ -8,112 +8,75 @@ package xorm
 import (
 "context"
- "fmt"
 "os"
- "reflect"
 "runtime"
- "sync"
 "time"
- "github.com/xormplus/core"
+ "github.com/xormplus/xorm/caches"
+ "github.com/xormplus/xorm/core"
+ "github.com/xormplus/xorm/dialects"
+ "github.com/xormplus/xorm/log"
+ "github.com/xormplus/xorm/names"
+ "github.com/xormplus/xorm/schemas"
+ "github.com/xormplus/xorm/tags"
 )
 const (
 // Version show the xorm's version
- Version string = "0.7.6.0327"
+ Version string = "0.8.0.0330"
 )
-func regDrvsNDialects() bool {
- providedDrvsNDialects := map[string]struct {
- dbType core.DbType
- getDriver func() core.Driver
- getDialect func() core.Dialect
- }{
- "mssql": {"mssql", func() core.Driver { return &odbcDriver{} }, func() core.Dialect { return &mssql{} }},
- "odbc": {"mssql", func() core.Driver { return &odbcDriver{} }, func() core.Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access
- "mysql": {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }},
- "mymysql": {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }},
- "postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }},
- "pgx": {"postgres", func() core.Driver { return &pqDriverPgx{} }, func() core.Dialect { return &postgres{} }},
- "sqlite3": {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }},
- "oci8": {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }},
- "goracle": {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }},
- }
-
- for driverName, v := range providedDrvsNDialects {
- if driver := core.QueryDriver(driverName); driver == nil {
- core.RegisterDriver(driverName, v.getDriver())
- core.RegisterDialect(v.dbType, v.getDialect)
- }
- }
- return true
-}
-
 func close(engine *Engine) {
 engine.Close()
 }
-func init() {
- regDrvsNDialects()
-}
-
// NewEngine new a db manager according to the parameter. 
Currently support four
// drivers
func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
- driver := core.QueryDriver(driverName)
- if driver == nil {
- return nil, fmt.Errorf("Unsupported driver name: %v", driverName)
- }
-
- uri, err := driver.Parse(driverName, dataSourceName)
+ dialect, err := dialects.OpenDialect(driverName, dataSourceName)
 if err != nil {
 return nil, err
 }
-
- dialect := core.QueryDialect(uri.DbType)
- if dialect == nil {
- return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DbType)
- }
-
 db, err := core.Open(driverName, dataSourceName)
 if err != nil {
 return nil, err
 }
- err = dialect.Init(db, uri, driverName, dataSourceName)
- if err != nil {
- return nil, err
- }
+ cacherMgr := caches.NewManager()
+ mapper := names.NewCacheMapper(new(names.SnakeMapper))
+ tagParser := tags.NewParser("xorm", dialect, mapper, mapper, cacherMgr)
 engine := &Engine{
- db: db,
 dialect: dialect,
- Tables: make(map[reflect.Type]*core.Table),
- mutex: &sync.RWMutex{},
- TagIdentifier: "xorm",
 TZLocation: time.Local,
- tagHandlers: defaultTagHandlers,
- cachers: make(map[string]core.Cacher),
 defaultContext: context.Background(),
+ cacherMgr: cacherMgr,
+ tagParser: tagParser,
+ driverName: driverName,
+ dataSourceName: dataSourceName,
+ db: db,
+ logSessionID: false,
 }
- if uri.DbType == core.SQLITE {
+ if dialect.URI().DBType == schemas.SQLITE {
 engine.DatabaseTZ = time.UTC
 } else {
 engine.DatabaseTZ = time.Local
 }
- logger := NewSimpleLogger(os.Stdout)
- logger.SetLevel(core.LOG_INFO)
- engine.SetLogger(logger)
- engine.SetMapper(core.NewCacheMapper(new(core.SnakeMapper)))
+ logger := log.NewSimpleLogger(os.Stdout)
+ logger.SetLevel(log.LOG_INFO)
+ engine.SetLogger(log.NewLoggerAdapter(logger))
- runtime.SetFinalizer(engine, close)
+ runtime.SetFinalizer(engine, func(engine *Engine) {
+ engine.Close()
+ })
 return engine, nil
 }
-// NewEngineWithParams new a db manager with params. The params will be passed to dialect.
+// NewEngineWithParams creates a new DB manager with extra params, which are passed on to the dialect.
 func NewEngineWithParams(driverName string, dataSourceName string, params map[string]string) (*Engine, error) {
 engine, err := NewEngine(driverName, dataSourceName)
 engine.dialect.SetParams(params)