Skip to content

Commit

Permalink
feat: support Partitioned DML
Browse files Browse the repository at this point in the history
Adds tests and samples for executing Partitioned DML using SQLAlchemy.

Fixes #496
  • Loading branch information
olavloite committed Dec 5, 2024
1 parent a633c23 commit 4712ec3
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 7 deletions.
13 changes: 13 additions & 0 deletions test/mockserver_tests/mock_server_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@ def add_result(sql: str, result: ResultSet):
MockServerTestBase.spanner_service.mock_spanner.add_result(sql, result)


def add_update_count(sql: str, count: int):
result = result_set.ResultSet(
dict(
stats=result_set.ResultSetStats(
dict(
row_count_exact=count,
)
),
)
)
add_result(sql, result)


def add_select1_result():
result = result_set.ResultSet(
dict(
Expand Down
25 changes: 19 additions & 6 deletions test/mockserver_tests/mock_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud.spanner_v1 import TransactionOptions, ResultSetMetadata
from google.cloud.spanner_v1 import (
TransactionOptions,
ResultSetMetadata,
ExecuteSqlRequest,
)
from google.protobuf import empty_pb2
import test.mockserver_tests.spanner_pb2_grpc as spanner_grpc
import test.mockserver_tests.spanner_database_admin_pb2_grpc as database_admin_grpc
Expand Down Expand Up @@ -40,23 +44,25 @@ def get_result(self, sql: str) -> result_set.ResultSet:
return result

def get_result_as_partial_result_sets(
self, sql: str
self, sql: str, started_transaction: transaction.Transaction
) -> [result_set.PartialResultSet]:
result: result_set.ResultSet = self.get_result(sql)
partials = []
first = True
if len(result.rows) == 0:
partial = result_set.PartialResultSet()
partial.metadata = result.metadata
partial.metadata = ResultSetMetadata(result.metadata)
partials.append(partial)
else:
for row in result.rows:
partial = result_set.PartialResultSet()
if first:
partial.metadata = result.metadata
partial.metadata = ResultSetMetadata(result.metadata)
partial.values.extend(row)
partials.append(partial)
partials[len(partials) - 1].stats = result.stats
if started_transaction:
partials[0].metadata.transaction = started_transaction
return partials


Expand Down Expand Up @@ -120,9 +126,16 @@ def ExecuteSql(self, request, context):
self._requests.append(request)
return result_set.ResultSet()

def ExecuteStreamingSql(self, request, context):
def ExecuteStreamingSql(self, request: ExecuteSqlRequest, context):
self._requests.append(request)
partials = self.mock_spanner.get_result_as_partial_result_sets(request.sql)
started_transaction = None
if not request.transaction.begin == TransactionOptions():
started_transaction = self.__create_transaction(
request.session, request.transaction.begin
)
partials = self.mock_spanner.get_result_as_partial_result_sets(
request.sql, started_transaction
)
for result in partials:
yield result

Expand Down
35 changes: 34 additions & 1 deletion test/mockserver_tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,19 @@
# limitations under the License.

from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
from sqlalchemy import create_engine, select, MetaData, Table, Column, Integer, String
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
from sqlalchemy import (
create_engine,
select,
MetaData,
Table,
Column,
Integer,
String,
update,
table,
column,
)
from sqlalchemy.testing import eq_, is_instance_of
from google.cloud.spanner_v1 import (
FixedSizePool,
Expand All @@ -26,6 +38,7 @@
MockServerTestBase,
add_select1_result,
add_result,
add_update_count,
)


Expand Down Expand Up @@ -127,3 +140,23 @@ def test_create_multiple_tables(self):
"\n) PRIMARY KEY (id)",
requests[0].statements[i],
)

def test_partitioned_dml(self):
sql = "UPDATE singers SET WHERE active = true"
add_update_count(sql, 100)
engine = create_engine(
"spanner:///projects/p/instances/i/databases/d",
connect_args={"client": self.client, "pool": PingingPool(size=10)},
)
# TODO: Support autocommit_dml_mode as a connection variable in execution
# options.
with engine.connect().execution_options(
isolation_level="AUTOCOMMIT"
) as connection:
connection.connection.set_autocommit_dml_mode(
AutocommitDmlMode.PARTITIONED_NON_ATOMIC
)
results = connection.execute(
update(table("singers")).where(column("active") is True)
).rowcount
eq_(100, results)

0 comments on commit 4712ec3

Please sign in to comment.