Skip to content

Commit

Permalink
fix: support THEN RETURN for insert, update, delete (#503)
Browse files Browse the repository at this point in the history
* fix: support THEN RETURN for insert, update, delete

Support THEN RETURN clauses for INSERT, UPDATE, and DELETE statements.

Fixes #498

* test: override insert auto-generated pk test
  • Loading branch information
olavloite authored Nov 28, 2024
1 parent 142cbee commit ac64472
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 1 deletion.
4 changes: 4 additions & 0 deletions google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,10 @@ class SpannerDialect(DefaultDialect):
supports_native_decimal = True
supports_statement_cache = True

insert_returning = True
update_returning = True
delete_returning = True

ddl_compiler = SpannerDDLCompiler
preparer = SpannerIdentifierPreparer
statement_compiler = SpannerSQLCompiler
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def mockserver(session):
"9999",
)
session.run(
"py.test", "--quiet", os.path.join("test/mockserver_tests"), *session.posargs
"py.test", "--quiet", os.path.join("test", "mockserver_tests"), *session.posargs
)


Expand Down
33 changes: 33 additions & 0 deletions test/mockserver_tests/bit_reversed_sequence_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sqlalchemy import String, BigInteger, Sequence, TextClause
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column


class Base(DeclarativeBase):
pass


class Singer(Base):
__tablename__ = "singers"
id: Mapped[int] = mapped_column(
BigInteger,
Sequence("singer_id"),
server_default=TextClause("GET_NEXT_SEQUENCE_VALUE(SEQUENCE singer_id)"),
primary_key=True,
)
name: Mapped[str] = mapped_column(String)
137 changes: 137 additions & 0 deletions test/mockserver_tests/test_bit_reversed_sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from sqlalchemy.testing import eq_, is_instance_of
from google.cloud.spanner_v1 import (
FixedSizePool,
ResultSet,
BatchCreateSessionsRequest,
ExecuteSqlRequest,
CommitRequest,
GetSessionRequest,
BeginTransactionRequest,
)
from test.mockserver_tests.mock_server_test_base import (
MockServerTestBase,
add_result,
)
from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
import google.cloud.spanner_v1.types.type as spanner_type
import google.cloud.spanner_v1.types.result_set as result_set


class TestBitReversedSequence(MockServerTestBase):
def test_create_table(self):
from test.mockserver_tests.bit_reversed_sequence_model import Base

add_result(
"""SELECT true
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_SCHEMA="" AND TABLE_NAME="singers"
LIMIT 1
""",
ResultSet(),
)
add_result(
"""SELECT true
FROM INFORMATION_SCHEMA.SEQUENCES
WHERE NAME="singer_id"
AND SCHEMA=""
LIMIT 1""",
ResultSet(),
)
engine = create_engine(
"spanner:///projects/p/instances/i/databases/d",
connect_args={"client": self.client, "pool": FixedSizePool(size=10)},
)
Base.metadata.create_all(engine)
requests = self.database_admin_service.requests
eq_(1, len(requests))
is_instance_of(requests[0], UpdateDatabaseDdlRequest)
eq_(2, len(requests[0].statements))
eq_(
"CREATE SEQUENCE singer_id OPTIONS "
"(sequence_kind = 'bit_reversed_positive')",
requests[0].statements[0],
)
eq_(
"CREATE TABLE singers (\n"
"\tid INT64 NOT NULL DEFAULT "
"(GET_NEXT_SEQUENCE_VALUE(SEQUENCE singer_id)), \n"
"\tname STRING(MAX) NOT NULL\n"
") PRIMARY KEY (id)",
requests[0].statements[1],
)

def test_insert_row(self):
from test.mockserver_tests.bit_reversed_sequence_model import Singer

result = result_set.ResultSet(
dict(
metadata=result_set.ResultSetMetadata(
dict(
row_type=spanner_type.StructType(
dict(
fields=[
spanner_type.StructType.Field(
dict(
name="id",
type=spanner_type.Type(
dict(code=spanner_type.TypeCode.INT64)
),
)
)
]
)
)
)
),
stats=result_set.ResultSetStats(
dict(
row_count_exact=1,
)
),
)
)
result.rows.extend(["1"])

add_result(
"INSERT INTO singers (id, name) "
"VALUES ( GET_NEXT_SEQUENCE_VALUE(SEQUENCE singer_id), @a0) "
"THEN RETURN singers.id",
result,
)
engine = create_engine(
"spanner:///projects/p/instances/i/databases/d",
connect_args={"client": self.client, "pool": FixedSizePool(size=10)},
)

with Session(engine) as session:
singer = Singer(name="Test")
session.add(singer)
# Flush the session to send the insert statement to the database.
session.flush()
eq_(1, singer.id)
session.commit()
# Verify the requests that we got.
requests = self.spanner_service.requests
eq_(5, len(requests))
is_instance_of(requests[0], BatchCreateSessionsRequest)
# We should get rid of this extra round-trip for GetSession....
is_instance_of(requests[1], GetSessionRequest)
is_instance_of(requests[2], BeginTransactionRequest)
is_instance_of(requests[3], ExecuteSqlRequest)
is_instance_of(requests[4], CommitRequest)
26 changes: 26 additions & 0 deletions test/test_suite_20.py
Original file line number Diff line number Diff line change
Expand Up @@ -2171,6 +2171,32 @@ def test_autoclose_on_insert(self):
assert r.is_insert
assert not r.returns_rows

def test_autoclose_on_insert_implicit_returning(self, connection):
"""
SPANNER OVERRIDE:
Cloud Spanner doesn't support tables with an auto increment primary key,
following insertions will fail with `400 id must not be NULL in table
autoinc_pk`.
Overriding the tests and adding a manual primary key value to avoid the same
failures.
"""
r = connection.execute(
# return_defaults() ensures RETURNING will be used,
# new in 2.0 as sqlite/mariadb offer both RETURNING and
# cursor.lastrowid
self.tables.autoinc_pk.insert().return_defaults(),
dict(id=2, data="some data"),
)
assert r._soft_closed
assert not r.closed
assert r.is_insert

# Spanner does not return any rows in this case, because the primary key
# is not auto-generated.
assert not r.returns_rows


class BytesTest(_LiteralRoundTripFixture, fixtures.TestBase):
__backend__ = True
Expand Down

0 comments on commit ac64472

Please sign in to comment.