Skip to content

Commit

Permalink
feat: support Partitioned DML
Browse files Browse the repository at this point in the history
Adds tests and samples for executing Partitioned DML using SQLAlchemy.

Fixes #496
  • Loading branch information
olavloite committed Dec 6, 2024
1 parent a633c23 commit 4b45894
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 7 deletions.
45 changes: 45 additions & 0 deletions samples/partitioned_dml_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
from sqlalchemy import create_engine, text

from sample_helper import run_sample

# Shows how to use Partitioned DML using SQLAlchemy and Spanner.
def partitioned_dml_sample():
engine = create_engine(
"spanner:///projects/sample-project/"
"instances/sample-instance/"
"databases/sample-database",
echo=True,
)
# Get a connection in auto-commit mode.
# Partitioned DML can only be executed in auto-commit mode, as each
# Partitioned DML transaction can only consist of one statement.
with engine.connect().execution_options(isolation_level="AUTOCOMMIT") as connection:
# Set the DML mode to PARTITIONED_NON_ATOMIC.
connection.connection.set_autocommit_dml_mode(
AutocommitDmlMode.PARTITIONED_NON_ATOMIC
)
# Use a bulk update statement to back-fill a column.
lower_bound_rowcount = connection.execute(
text("update venues set active=true where active is null")
).rowcount
# Partitioned DML returns the lower-bound update count.
print("Updated at least ", lower_bound_rowcount, " venue records")


if __name__ == "__main__":
run_sample(partitioned_dml_sample)
12 changes: 12 additions & 0 deletions test/mockserver_tests/mock_server_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
from sqlalchemy import Engine, create_engine
from sqlalchemy.testing.plugin.plugin_base import fixtures
import google.cloud.spanner_v1.types.type as spanner_type
Expand All @@ -35,6 +36,17 @@ def add_result(sql: str, result: ResultSet):
MockServerTestBase.spanner_service.mock_spanner.add_result(sql, result)


def add_update_count(
sql: str, count: int, dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL
):
if dml_mode == AutocommitDmlMode.PARTITIONED_NON_ATOMIC:
stats = dict(row_count_lower_bound=count)
else:
stats = dict(row_count_exact=count)
result = result_set.ResultSet(dict(stats=result_set.ResultSetStats(stats)))
add_result(sql, result)


def add_select1_result():
result = result_set.ResultSet(
dict(
Expand Down
25 changes: 19 additions & 6 deletions test/mockserver_tests/mock_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud.spanner_v1 import TransactionOptions, ResultSetMetadata
from google.cloud.spanner_v1 import (
TransactionOptions,
ResultSetMetadata,
ExecuteSqlRequest,
)
from google.protobuf import empty_pb2
import test.mockserver_tests.spanner_pb2_grpc as spanner_grpc
import test.mockserver_tests.spanner_database_admin_pb2_grpc as database_admin_grpc
Expand Down Expand Up @@ -40,23 +44,25 @@ def get_result(self, sql: str) -> result_set.ResultSet:
return result

def get_result_as_partial_result_sets(
self, sql: str
self, sql: str, started_transaction: transaction.Transaction
) -> [result_set.PartialResultSet]:
result: result_set.ResultSet = self.get_result(sql)
partials = []
first = True
if len(result.rows) == 0:
partial = result_set.PartialResultSet()
partial.metadata = result.metadata
partial.metadata = ResultSetMetadata(result.metadata)
partials.append(partial)
else:
for row in result.rows:
partial = result_set.PartialResultSet()
if first:
partial.metadata = result.metadata
partial.metadata = ResultSetMetadata(result.metadata)
partial.values.extend(row)
partials.append(partial)
partials[len(partials) - 1].stats = result.stats
if started_transaction:
partials[0].metadata.transaction = started_transaction
return partials


Expand Down Expand Up @@ -120,9 +126,16 @@ def ExecuteSql(self, request, context):
self._requests.append(request)
return result_set.ResultSet()

def ExecuteStreamingSql(self, request, context):
def ExecuteStreamingSql(self, request: ExecuteSqlRequest, context):
self._requests.append(request)
partials = self.mock_spanner.get_result_as_partial_result_sets(request.sql)
started_transaction = None
if not request.transaction.begin == TransactionOptions():
started_transaction = self.__create_transaction(
request.session, request.transaction.begin
)
partials = self.mock_spanner.get_result_as_partial_result_sets(
request.sql, started_transaction
)
for result in partials:
yield result

Expand Down
31 changes: 30 additions & 1 deletion test/mockserver_tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,17 @@
# limitations under the License.

from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
from sqlalchemy import create_engine, select, MetaData, Table, Column, Integer, String
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
from sqlalchemy import (
create_engine,
select,
MetaData,
Table,
Column,
Integer,
String,
text,
)
from sqlalchemy.testing import eq_, is_instance_of
from google.cloud.spanner_v1 import (
FixedSizePool,
Expand All @@ -26,6 +36,7 @@
MockServerTestBase,
add_select1_result,
add_result,
add_update_count,
)


Expand Down Expand Up @@ -127,3 +138,21 @@ def test_create_multiple_tables(self):
"\n) PRIMARY KEY (id)",
requests[0].statements[i],
)

def test_partitioned_dml(self):
sql = "UPDATE singers SET checked=true WHERE active = true"
add_update_count(sql, 100, AutocommitDmlMode.PARTITIONED_NON_ATOMIC)
engine = create_engine(
"spanner:///projects/p/instances/i/databases/d",
connect_args={"client": self.client, "pool": PingingPool(size=10)},
)
# TODO: Support autocommit_dml_mode as a connection variable in execution
# options.
with engine.connect().execution_options(
isolation_level="AUTOCOMMIT"
) as connection:
connection.connection.set_autocommit_dml_mode(
AutocommitDmlMode.PARTITIONED_NON_ATOMIC
)
results = connection.execute(text(sql)).rowcount
eq_(100, results)

0 comments on commit 4b45894

Please sign in to comment.