From ac644726665213f234ce8ec4dea715c820a670e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 28 Nov 2024 14:22:34 +0100 Subject: [PATCH] fix: support THEN RETURN for insert, update, delete (#503) * fix: support THEN RETURN for insert, update, delete Support THEN RETURN clauses for INSERT, UPDATE, and DELETE statements. Fixes #498 * test: override insert auto-generated pk test --- .../sqlalchemy_spanner/sqlalchemy_spanner.py | 4 + noxfile.py | 2 +- .../bit_reversed_sequence_model.py | 33 +++++ .../test_bit_reversed_sequence.py | 137 ++++++++++++++++++ test/test_suite_20.py | 26 ++++ 5 files changed, 201 insertions(+), 1 deletion(-) create mode 100644 test/mockserver_tests/bit_reversed_sequence_model.py create mode 100644 test/mockserver_tests/test_bit_reversed_sequence.py diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index e57d46a8..e2fb651d 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -611,6 +611,10 @@ class SpannerDialect(DefaultDialect): supports_native_decimal = True supports_statement_cache = True + insert_returning = True + update_returning = True + delete_returning = True + ddl_compiler = SpannerDDLCompiler preparer = SpannerIdentifierPreparer statement_compiler = SpannerSQLCompiler diff --git a/noxfile.py b/noxfile.py index 2c4b21bc..974daf99 100644 --- a/noxfile.py +++ b/noxfile.py @@ -327,7 +327,7 @@ def mockserver(session): "9999", ) session.run( - "py.test", "--quiet", os.path.join("test/mockserver_tests"), *session.posargs + "py.test", "--quiet", os.path.join("test", "mockserver_tests"), *session.posargs ) diff --git a/test/mockserver_tests/bit_reversed_sequence_model.py b/test/mockserver_tests/bit_reversed_sequence_model.py new file mode 100644 index 00000000..b76cdd3f --- /dev/null +++ b/test/mockserver_tests/bit_reversed_sequence_model.py @@ -0,0 +1,33 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy import String, BigInteger, Sequence, TextClause +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class Singer(Base): + __tablename__ = "singers" + id: Mapped[int] = mapped_column( + BigInteger, + Sequence("singer_id"), + server_default=TextClause("GET_NEXT_SEQUENCE_VALUE(SEQUENCE singer_id)"), + primary_key=True, + ) + name: Mapped[str] = mapped_column(String) diff --git a/test/mockserver_tests/test_bit_reversed_sequence.py b/test/mockserver_tests/test_bit_reversed_sequence.py new file mode 100644 index 00000000..82822d44 --- /dev/null +++ b/test/mockserver_tests/test_bit_reversed_sequence.py @@ -0,0 +1,137 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy import create_engine +from sqlalchemy.orm import Session +from sqlalchemy.testing import eq_, is_instance_of +from google.cloud.spanner_v1 import ( + FixedSizePool, + ResultSet, + BatchCreateSessionsRequest, + ExecuteSqlRequest, + CommitRequest, + GetSessionRequest, + BeginTransactionRequest, +) +from test.mockserver_tests.mock_server_test_base import ( + MockServerTestBase, + add_result, +) +from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest +import google.cloud.spanner_v1.types.type as spanner_type +import google.cloud.spanner_v1.types.result_set as result_set + + +class TestBitReversedSequence(MockServerTestBase): + def test_create_table(self): + from test.mockserver_tests.bit_reversed_sequence_model import Base + + add_result( + """SELECT true +FROM INFORMATION_SCHEMA.TABLES +WHERE TABLE_SCHEMA="" AND TABLE_NAME="singers" +LIMIT 1 +""", + ResultSet(), + ) + add_result( + """SELECT true + FROM INFORMATION_SCHEMA.SEQUENCES + WHERE NAME="singer_id" + AND SCHEMA="" + LIMIT 1""", + ResultSet(), + ) + engine = create_engine( + "spanner:///projects/p/instances/i/databases/d", + connect_args={"client": self.client, "pool": FixedSizePool(size=10)}, + ) + Base.metadata.create_all(engine) + requests = self.database_admin_service.requests + eq_(1, len(requests)) + is_instance_of(requests[0], UpdateDatabaseDdlRequest) + eq_(2, len(requests[0].statements)) + eq_( + "CREATE SEQUENCE singer_id OPTIONS " + "(sequence_kind = 'bit_reversed_positive')", + requests[0].statements[0], + ) + eq_( + "CREATE TABLE singers (\n" + "\tid INT64 NOT NULL DEFAULT " + "(GET_NEXT_SEQUENCE_VALUE(SEQUENCE singer_id)), \n" + "\tname STRING(MAX) NOT NULL\n" + ") PRIMARY KEY (id)", + requests[0].statements[1], + ) + + def test_insert_row(self): + from test.mockserver_tests.bit_reversed_sequence_model import Singer + + result = result_set.ResultSet( + dict( + metadata=result_set.ResultSetMetadata( + dict( + row_type=spanner_type.StructType( + dict( + fields=[ + spanner_type.StructType.Field( + dict( + name="id", + type=spanner_type.Type( + dict(code=spanner_type.TypeCode.INT64) + ), + ) + ) + ] + ) + ) + ) + ), + stats=result_set.ResultSetStats( + dict( + row_count_exact=1, + ) + ), + ) + ) + result.rows.extend(["1"]) + + add_result( + "INSERT INTO singers (id, name) " + "VALUES ( GET_NEXT_SEQUENCE_VALUE(SEQUENCE singer_id), @a0) " + "THEN RETURN singers.id", + result, + ) + engine = create_engine( + "spanner:///projects/p/instances/i/databases/d", + connect_args={"client": self.client, "pool": FixedSizePool(size=10)}, + ) + + with Session(engine) as session: + singer = Singer(name="Test") + session.add(singer) + # Flush the session to send the insert statement to the database. + session.flush() + eq_(1, singer.id) + session.commit() + # Verify the requests that we got. + requests = self.spanner_service.requests + eq_(5, len(requests)) + is_instance_of(requests[0], BatchCreateSessionsRequest) + # We should get rid of this extra round-trip for GetSession.... + is_instance_of(requests[1], GetSessionRequest) + is_instance_of(requests[2], BeginTransactionRequest) + is_instance_of(requests[3], ExecuteSqlRequest) + is_instance_of(requests[4], CommitRequest) diff --git a/test/test_suite_20.py b/test/test_suite_20.py index 22b23e0a..dbbc8f88 100644 --- a/test/test_suite_20.py +++ b/test/test_suite_20.py @@ -2171,6 +2171,32 @@ def test_autoclose_on_insert(self): assert r.is_insert assert not r.returns_rows + def test_autoclose_on_insert_implicit_returning(self, connection): + """ + SPANNER OVERRIDE: + + Cloud Spanner doesn't support tables with an auto increment primary key, + following insertions will fail with `400 id must not be NULL in table + autoinc_pk`. + + Overriding the tests and adding a manual primary key value to avoid the same + failures. + """ + r = connection.execute( + # return_defaults() ensures RETURNING will be used, + # new in 2.0 as sqlite/mariadb offer both RETURNING and + # cursor.lastrowid + self.tables.autoinc_pk.insert().return_defaults(), + dict(id=2, data="some data"), + ) + assert r._soft_closed + assert not r.closed + assert r.is_insert + + # Spanner does not return any rows in this case, because the primary key + # is not auto-generated. + assert not r.returns_rows + class BytesTest(_LiteralRoundTripFixture, fixtures.TestBase): __backend__ = True