From 4b45894f6b34d99c85c17ae590fc749e831f388e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 5 Dec 2024 19:04:16 +0100 Subject: [PATCH] feat: support Partitioned DML Adds tests and samples for executing Partitioned DML using SQLAlchemy. Fixes #496 --- samples/partitioned_dml_sample.py | 45 +++++++++++++++++++ .../mockserver_tests/mock_server_test_base.py | 12 +++++ test/mockserver_tests/mock_spanner.py | 25 ++++++++--- test/mockserver_tests/test_basics.py | 31 ++++++++++++- 4 files changed, 106 insertions(+), 7 deletions(-) create mode 100644 samples/partitioned_dml_sample.py diff --git a/samples/partitioned_dml_sample.py b/samples/partitioned_dml_sample.py new file mode 100644 index 00000000..62c312ff --- /dev/null +++ b/samples/partitioned_dml_sample.py @@ -0,0 +1,45 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode +from sqlalchemy import create_engine, text + +from sample_helper import run_sample + +# Shows how to use Partitioned DML using SQLAlchemy and Spanner. +def partitioned_dml_sample(): + engine = create_engine( + "spanner:///projects/sample-project/" + "instances/sample-instance/" + "databases/sample-database", + echo=True, + ) + # Get a connection in auto-commit mode. + # Partitioned DML can only be executed in auto-commit mode, as each + # Partitioned DML transaction can only consist of one statement. + with engine.connect().execution_options(isolation_level="AUTOCOMMIT") as connection: + # Set the DML mode to PARTITIONED_NON_ATOMIC. + connection.connection.set_autocommit_dml_mode( + AutocommitDmlMode.PARTITIONED_NON_ATOMIC + ) + # Use a bulk update statement to back-fill a column. + lower_bound_rowcount = connection.execute( + text("update venues set active=true where active is null") + ).rowcount + # Partitioned DML returns the lower-bound update count. + print("Updated at least ", lower_bound_rowcount, " venue records") + + +if __name__ == "__main__": + run_sample(partitioned_dml_sample) diff --git a/test/mockserver_tests/mock_server_test_base.py b/test/mockserver_tests/mock_server_test_base.py index 71e1bf1f..5aa33732 100644 --- a/test/mockserver_tests/mock_server_test_base.py +++ b/test/mockserver_tests/mock_server_test_base.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode from sqlalchemy import Engine, create_engine from sqlalchemy.testing.plugin.plugin_base import fixtures import google.cloud.spanner_v1.types.type as spanner_type @@ -35,6 +36,17 @@ def add_result(sql: str, result: ResultSet): MockServerTestBase.spanner_service.mock_spanner.add_result(sql, result) +def add_update_count( + sql: str, count: int, dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL +): + if dml_mode == AutocommitDmlMode.PARTITIONED_NON_ATOMIC: + stats = dict(row_count_lower_bound=count) + else: + stats = dict(row_count_exact=count) + result = result_set.ResultSet(dict(stats=result_set.ResultSetStats(stats))) + add_result(sql, result) + + def add_select1_result(): result = result_set.ResultSet( dict( diff --git a/test/mockserver_tests/mock_spanner.py b/test/mockserver_tests/mock_spanner.py index db189144..932f6371 100644 --- a/test/mockserver_tests/mock_spanner.py +++ b/test/mockserver_tests/mock_spanner.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from google.cloud.spanner_v1 import TransactionOptions, ResultSetMetadata +from google.cloud.spanner_v1 import ( + TransactionOptions, + ResultSetMetadata, + ExecuteSqlRequest, +) from google.protobuf import empty_pb2 import test.mockserver_tests.spanner_pb2_grpc as spanner_grpc import test.mockserver_tests.spanner_database_admin_pb2_grpc as database_admin_grpc @@ -40,23 +44,25 @@ def get_result(self, sql: str) -> result_set.ResultSet: return result def get_result_as_partial_result_sets( - self, sql: str + self, sql: str, started_transaction: transaction.Transaction ) -> [result_set.PartialResultSet]: result: result_set.ResultSet = self.get_result(sql) partials = [] first = True if len(result.rows) == 0: partial = result_set.PartialResultSet() - partial.metadata = result.metadata + partial.metadata = ResultSetMetadata(result.metadata) partials.append(partial) else: for row in result.rows: partial = result_set.PartialResultSet() if first: - partial.metadata = result.metadata + partial.metadata = ResultSetMetadata(result.metadata) partial.values.extend(row) partials.append(partial) partials[len(partials) - 1].stats = result.stats + if started_transaction: + partials[0].metadata.transaction = started_transaction return partials @@ -120,9 +126,16 @@ def ExecuteSql(self, request, context): self._requests.append(request) return result_set.ResultSet() - def ExecuteStreamingSql(self, request, context): + def ExecuteStreamingSql(self, request: ExecuteSqlRequest, context): self._requests.append(request) - partials = self.mock_spanner.get_result_as_partial_result_sets(request.sql) + started_transaction = None + if not request.transaction.begin == TransactionOptions(): + started_transaction = self.__create_transaction( + request.session, request.transaction.begin + ) + partials = self.mock_spanner.get_result_as_partial_result_sets( + request.sql, started_transaction + ) for result in partials: yield result diff --git a/test/mockserver_tests/test_basics.py b/test/mockserver_tests/test_basics.py index b6c916c4..82918366 100644 --- a/test/mockserver_tests/test_basics.py +++ b/test/mockserver_tests/test_basics.py @@ -13,7 +13,17 @@ # limitations under the License. from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest -from sqlalchemy import create_engine, select, MetaData, Table, Column, Integer, String +from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode +from sqlalchemy import ( + create_engine, + select, + MetaData, + Table, + Column, + Integer, + String, + text, +) from sqlalchemy.testing import eq_, is_instance_of from google.cloud.spanner_v1 import ( FixedSizePool, @@ -26,6 +36,7 @@ MockServerTestBase, add_select1_result, add_result, + add_update_count, ) @@ -127,3 +138,21 @@ def test_create_multiple_tables(self): "\n) PRIMARY KEY (id)", requests[0].statements[i], ) + + def test_partitioned_dml(self): + sql = "UPDATE singers SET checked=true WHERE active = true" + add_update_count(sql, 100, AutocommitDmlMode.PARTITIONED_NON_ATOMIC) + engine = create_engine( + "spanner:///projects/p/instances/i/databases/d", + connect_args={"client": self.client, "pool": PingingPool(size=10)}, + ) + # TODO: Support autocommit_dml_mode as a connection variable in execution + # options. + with engine.connect().execution_options( + isolation_level="AUTOCOMMIT" + ) as connection: + connection.connection.set_autocommit_dml_mode( + AutocommitDmlMode.PARTITIONED_NON_ATOMIC + ) + results = connection.execute(text(sql)).rowcount + eq_(100, results)