From aaa219b1d96a173a541dcf130a000e155d02d283 Mon Sep 17 00:00:00 2001 From: devgor88 Date: Tue, 10 Dec 2024 22:41:41 +0300 Subject: [PATCH] feat: [MariaDB] Support RETURNING clause (#2330) * feat: implement returning clause for MariaDB * fix: clean ReturningTests * fix: docs update --- .../topics/DSL-CRUD-operations.topic | 1 + .../exposed/sql/vendors/MariaDBDialect.kt | 12 ++++++++ .../sql/tests/shared/dml/ReturningTests.kt | 30 ++++++++++--------- 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/documentation-website/Writerside/topics/DSL-CRUD-operations.topic b/documentation-website/Writerside/topics/DSL-CRUD-operations.topic index 0fba69a6c2..f00ea9e957 100644 --- a/documentation-website/Writerside/topics/DSL-CRUD-operations.topic +++ b/documentation-website/Writerside/topics/DSL-CRUD-operations.topic @@ -406,6 +406,7 @@

Supported on: PostgreSQL and SQLite

Some databases allow the return of additional data every time a row is either inserted, updated, or deleted. + Please note that MariaDB only allows the return of this data for insertions and deletions. This can be accomplished by using one of the following functions:

diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/MariaDBDialect.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/MariaDBDialect.kt index 4299efb852..623627f3e1 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/MariaDBDialect.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/MariaDBDialect.kt @@ -48,6 +48,18 @@ internal object MariaDBFunctionProvider : MysqlFunctionProvider() { sql } } + + override fun returning( + mainSql: String, + returning: List>, + transaction: Transaction + ): String { + return with(QueryBuilder(true)) { + +"$mainSql RETURNING " + returning.appendTo { +it } + toString() + } + } } /** diff --git a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ReturningTests.kt b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ReturningTests.kt index a6cd0b0d2b..422a3402f1 100644 --- a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ReturningTests.kt +++ b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ReturningTests.kt @@ -17,7 +17,8 @@ import kotlin.test.assertIs import kotlin.test.assertTrue class ReturningTests : DatabaseTestsBase() { - private val returningSupportedDb = TestDB.ALL_POSTGRES.toSet() + TestDB.SQLITE + private val updateReturningSupportedDb = TestDB.ALL_POSTGRES.toSet() + TestDB.SQLITE + private val returningSupportedDb = updateReturningSupportedDb + TestDB.MARIADB object Items : IntIdTable("items") { val name = varchar("name", 32) @@ -89,7 +90,7 @@ class ReturningTests : DatabaseTestsBase() { @Test fun testUpsertReturning() { - withTables(TestDB.ALL - returningSupportedDb, Items) { + withTables(TestDB.ALL - returningSupportedDb, Items) { testDB -> // return all columns by default val result1 = Items.upsertReturning { it[name] = "A" @@ -110,17 +111,18 @@ class ReturningTests : DatabaseTestsBase() { assertEquals("A", result2[Items.name]) assertEquals(990.0, result2[Items.price]) - val result3 = Items.upsertReturning( - returning = listOf(Items.name), - onUpdateExclude = listOf(Items.price), - where = { Items.price greater 500.0 } - ) { - it[id] = 1 - it[name] = "B" - it[price] = 200.0 - }.single() - assertEquals("B", result3[Items.name]) - + if (testDB != TestDB.MARIADB) { + val result3 = Items.upsertReturning( + returning = listOf(Items.name), + onUpdateExclude = listOf(Items.price), + where = { Items.price greater 500.0 } + ) { + it[id] = 1 + it[name] = "B" + it[price] = 200.0 + }.single() + assertEquals("B", result3[Items.name]) + } assertEquals(1, Items.selectAll().count()) } } @@ -196,7 +198,7 @@ class ReturningTests : DatabaseTestsBase() { @Test fun testUpdateReturning() { - withTables(TestDB.enabledDialects() - returningSupportedDb, Items) { + withTables(TestDB.enabledDialects() - updateReturningSupportedDb, Items) { val input = listOf("A" to 99.0, "B" to 100.0, "C" to 200.0) Items.batchInsert(input) { (n, p) -> this[Items.name] = n