diff --git a/RoomigrantCompiler/src/main/java/dev/matrix/roomigrant/compiler/Migration.kt b/RoomigrantCompiler/src/main/java/dev/matrix/roomigrant/compiler/Migration.kt index ca8e60a..16d9720 100644 --- a/RoomigrantCompiler/src/main/java/dev/matrix/roomigrant/compiler/Migration.kt +++ b/RoomigrantCompiler/src/main/java/dev/matrix/roomigrant/compiler/Migration.kt @@ -34,13 +34,13 @@ class Migration( funcSpecBuilder.addModifiers(KModifier.OVERRIDE) funcSpecBuilder.addParameter(databaseArgName, state.sqLiteDatabaseType) - state.rules.getOnStartRule(scheme1.version, scheme2.version)?.also { + state.rules.getOnStartRules(scheme1.version, scheme2.version).forEach { funcSpecBuilder.addStatement("%L", it.getInvokeCode(databaseArgName, scheme1.version, scheme2.version)) } migrate() - state.rules.getOnEndRule(scheme1.version, scheme2.version)?.also { + state.rules.getOnEndRules(scheme1.version, scheme2.version).forEach { funcSpecBuilder.addStatement("%L", it.getInvokeCode(databaseArgName, scheme1.version, scheme2.version)) } diff --git a/RoomigrantCompiler/src/main/java/dev/matrix/roomigrant/compiler/rules/RuleByVersion.kt b/RoomigrantCompiler/src/main/java/dev/matrix/roomigrant/compiler/rules/RuleByVersion.kt index 2217b91..59166e4 100644 --- a/RoomigrantCompiler/src/main/java/dev/matrix/roomigrant/compiler/rules/RuleByVersion.kt +++ b/RoomigrantCompiler/src/main/java/dev/matrix/roomigrant/compiler/rules/RuleByVersion.kt @@ -7,3 +7,9 @@ interface RuleByVersion { val version1: Int val version2: Int } + +fun T.checkVersion(version1: Int, version2: Int): Boolean { + if (this.version1 != -1 && this.version1 != version1) return false + if (this.version2 != -1 && this.version2 != version2) return false + return true +} diff --git a/RoomigrantCompiler/src/main/java/dev/matrix/roomigrant/compiler/rules/Rules.kt b/RoomigrantCompiler/src/main/java/dev/matrix/roomigrant/compiler/rules/Rules.kt index 4206103..7a77685 100644 --- a/RoomigrantCompiler/src/main/java/dev/matrix/roomigrant/compiler/rules/Rules.kt +++ b/RoomigrantCompiler/src/main/java/dev/matrix/roomigrant/compiler/rules/Rules.kt @@ -61,19 +61,15 @@ class Rules(private val database: Database, element: TypeElement) { } fun getFieldRule(version1: Int, version2: Int, table: String, field: String): FieldRule? { - return fieldRules[table]?.get(field)?.findByVersion(version1, version2) + return fieldRules[table]?.get(field)?.find { it.checkVersion(version1, version2) } } - fun getOnEndRule(version1: Int, version2: Int): InvokeRule? { - return onEndRules.findByVersion(version1, version2) + fun getOnEndRules(version1: Int, version2: Int): Sequence { + return onEndRules.asSequence().filter { it.checkVersion(version1, version2) } } - fun getOnStartRule(version1: Int, version2: Int): InvokeRule? { - return onStartRules.findByVersion(version1, version2) - } - - private fun Iterable.findByVersion(version1: Int, version2: Int): T? { - return find { (it.version1 == -1 || it.version1 == version1) && (it.version2 == -1 || it.version2 == version2) } + fun getOnStartRules(version1: Int, version2: Int): Sequence { + return onStartRules.asSequence().filter { it.checkVersion(version1, version2) } } } diff --git a/RoomigrantTest/src/main/java/dev/matrix/roomigrant/test/Rules.kt b/RoomigrantTest/src/main/java/dev/matrix/roomigrant/test/Rules.kt index 00d972c..8d723f8 100644 --- a/RoomigrantTest/src/main/java/dev/matrix/roomigrant/test/Rules.kt +++ b/RoomigrantTest/src/main/java/dev/matrix/roomigrant/test/Rules.kt @@ -27,9 +27,21 @@ class Rules { assert(cursor.count == 1) } + @OnMigrationStartRule(version1 = 1) + fun migrate_1_n_before(db: SupportSQLiteDatabase, version1: Int, version2: Int) { + val cursor = db.query("pragma table_info(Object1Dbo)") + assert(cursor.count == 1) + } + @OnMigrationEndRule(version1 = 1, version2 = 2) fun migrate_1_2_after(db: SupportSQLiteDatabase, version1: Int, version2: Int) { val cursor = db.query("pragma table_info(Object1Dbo)") assert(cursor.count == 3) } + + @OnMigrationEndRule(version2 = 2) + fun migrate_n_2_after(db: SupportSQLiteDatabase, version1: Int, version2: Int) { + val cursor = db.query("pragma table_info(Object1Dbo)") + assert(cursor.count == 3) + } }