From 5071e3dfe4cf345b88c295bffddec25fcf23f27d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Mon, 2 Dec 2024 17:30:21 +0100 Subject: [PATCH] feat: support float32 Adds support for FLOAT32 columns. Applications should use the SQLAlchemy type REAL to create a FLOAT32 column, as FLOAT is already reserved for FLOAT64. Fixes #409 --- .../sqlalchemy_spanner/sqlalchemy_spanner.py | 14 ++++ test/mockserver_tests/float32_model.py | 30 ++++++++ test/mockserver_tests/test_float32.py | 72 +++++++++++++++++++ test/mockserver_tests/test_quickstart.py | 1 - test/system/test_basics.py | 38 +++++++++- 5 files changed, 151 insertions(+), 4 deletions(-) create mode 100644 test/mockserver_tests/float32_model.py create mode 100644 test/mockserver_tests/test_float32.py diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index e2fb651d..6786ec61 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -84,6 +84,9 @@ def reset_connection(dbapi_conn, connection_record, reset_state=None): "BYTES": types.LargeBinary, "DATE": types.DATE, "DATETIME": types.DATETIME, + "FLOAT32": types.REAL, + # Note: FLOAT64 was mapped to Float when Spanner only supported FLOAT64 + # This should however rather have been types.Double. "FLOAT64": types.Float, "INT64": types.BIGINT, "NUMERIC": types.NUMERIC(precision=38, scale=9), @@ -101,7 +104,9 @@ def reset_connection(dbapi_conn, connection_record, reset_state=None): types.LargeBinary: "BYTES(MAX)", types.DATE: "DATE", types.DATETIME: "DATETIME", + types.REAL: "FLOAT32", types.Float: "FLOAT64", + types.DOUBLE: "FLOAT64", types.BIGINT: "INT64", types.DECIMAL: "NUMERIC", types.String: "STRING", @@ -540,9 +545,18 @@ class SpannerTypeCompiler(GenericTypeCompiler): def visit_INTEGER(self, type_, **kw): return "INT64" + def visit_DOUBLE(self, type_, **kw): + return "FLOAT64" + def visit_FLOAT(self, type_, **kw): + # Note: This was added before Spanner supported FLOAT32. + # Changing this now to generate a FLOAT32 would be a breaking change. + # Users therefore have to use REAL to generate a FLOAT32 column. return "FLOAT64" + def visit_REAL(self, type_, **kw): + return "FLOAT32" + def visit_TEXT(self, type_, **kw): return "STRING({})".format(type_.length or "MAX") diff --git a/test/mockserver_tests/float32_model.py b/test/mockserver_tests/float32_model.py new file mode 100644 index 00000000..b6987e97 --- /dev/null +++ b/test/mockserver_tests/float32_model.py @@ -0,0 +1,30 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy import String +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.types import REAL + + +class Base(DeclarativeBase): + pass + + +class Number(Base): + __tablename__ = "numbers" + number: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(30)) + ln: Mapped[float] = mapped_column(REAL) diff --git a/test/mockserver_tests/test_float32.py b/test/mockserver_tests/test_float32.py new file mode 100644 index 00000000..3845231b --- /dev/null +++ b/test/mockserver_tests/test_float32.py @@ -0,0 +1,72 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest +from sqlalchemy import create_engine, select, MetaData, Table, Column, Integer, String +from sqlalchemy.orm import Session +from sqlalchemy.testing import eq_, is_instance_of, is_not_none, is_none, \ + is_false, is_true +from google.cloud.spanner_v1 import ( + FixedSizePool, + BatchCreateSessionsRequest, + ExecuteSqlRequest, + ResultSet, + PingingPool, ResultSetStats, BeginTransactionRequest, + ExecuteBatchDmlRequest, CommitRequest, TypeCode, +) +from test.mockserver_tests.mock_server_test_base import ( + MockServerTestBase, + add_select1_result, + add_result, +) + + +class TestFloat32(MockServerTestBase): + def test_insert_data(self): + from test.mockserver_tests.float32_model import Number + + update_count = ResultSet( + dict( + stats=ResultSetStats( + dict( + row_count_exact=1, + ) + ) + ) + ) + add_result( + "INSERT INTO numbers (number, name, ln) VALUES (@a0, @a1, @a2)", + update_count, + ) + + engine = self.create_engine() + with Session(engine) as session: + n1 = Number(number=1, name="One", ln=0.0) + session.add_all([n1]) + session.commit() + + requests = self.spanner_service.requests + eq_(4, len(requests)) + is_instance_of(requests[0], BatchCreateSessionsRequest) + is_instance_of(requests[1], BeginTransactionRequest) + is_instance_of(requests[2], ExecuteSqlRequest) + is_instance_of(requests[3], CommitRequest) + request: ExecuteSqlRequest = requests[2] + eq_(3, len(request.params)) + eq_("1", request.params["a0"]) + eq_("One", request.params["a1"]) + eq_(0.0, request.params["a2"]) + eq_(TypeCode.INT64, request.param_types["a0"].code) + eq_(TypeCode.STRING, request.param_types["a1"].code) + is_false("a2" in request.param_types) diff --git a/test/mockserver_tests/test_quickstart.py b/test/mockserver_tests/test_quickstart.py index ce9711f7..0b31f9e2 100644 --- a/test/mockserver_tests/test_quickstart.py +++ b/test/mockserver_tests/test_quickstart.py @@ -30,7 +30,6 @@ class TestQuickStart(MockServerTestBase): def test_create_tables(self): from test.mockserver_tests.quickstart_model import Base - # TODO: Fix the double quotes inside these SQL fragments. add_result( """SELECT true FROM INFORMATION_SCHEMA.TABLES diff --git a/test/system/test_basics.py b/test/system/test_basics.py index 3357104c..edbba54b 100644 --- a/test/system/test_basics.py +++ b/test/system/test_basics.py @@ -21,8 +21,10 @@ String, Index, MetaData, - Boolean, + Boolean, BIGINT, ) +from sqlalchemy.orm import Session, DeclarativeBase, Mapped, mapped_column +from sqlalchemy.types import REAL from sqlalchemy.testing import eq_ from sqlalchemy.testing.plugin.plugin_base import fixtures @@ -37,6 +39,7 @@ def define_tables(cls, metadata): Column("name", String(20)), Column("alternative_name", String(20)), Column("prime", Boolean), + Column("ln", REAL), PrimaryKeyConstraint("number"), ) Index( @@ -53,8 +56,8 @@ def test_hello_world(self, connection): def test_insert_number(self, connection): connection.execute( text( - """insert or update into numbers (number, name, prime) - values (1, 'One', false)""" + """insert or update into numbers (number, name, prime, ln) + values (1, 'One', false, cast(ln(1) as float32))""" ) ) name = connection.execute(text("select name from numbers where number=1")) @@ -66,6 +69,17 @@ def test_reflect(self, connection): meta.reflect(bind=engine) eq_(1, len(meta.tables)) table = meta.tables["numbers"] + eq_(5, len(table.columns)) + eq_("number", table.columns[0].name) + eq_(BIGINT, type(table.columns[0].type)) + eq_("name", table.columns[1].name) + eq_(String, type(table.columns[1].type)) + eq_("alternative_name", table.columns[2].name) + eq_(String, type(table.columns[2].type)) + eq_("prime", table.columns[3].name) + eq_(Boolean, type(table.columns[3].type)) + eq_("ln", table.columns[4].name) + eq_(REAL, type(table.columns[4].type)) eq_(1, len(table.indexes)) index = next(iter(table.indexes)) eq_(2, len(index.columns)) @@ -74,3 +88,21 @@ def test_reflect(self, connection): dialect_options = index.dialect_options["spanner"] eq_(1, len(dialect_options["storing"])) eq_("alternative_name", dialect_options["storing"][0]) + + def test_orm(self, connection): + class Base(DeclarativeBase): + pass + + class Number(Base): + __tablename__ = "numbers" + number: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(20)) + alternative_name: Mapped[str] = mapped_column(String(20)) + prime: Mapped[bool] = mapped_column(Boolean) + ln: Mapped[float] = mapped_column(REAL) + + engine = connection.engine + with Session(engine) as session: + number = Number(number=1, name="One", alternative_name="Uno", prime=False, ln=0.0) + session.add(number) + session.commit()