Skip to content

Commit

Permalink
feat: support float32 (#531)
Browse files Browse the repository at this point in the history
* feat: support float32

Adds support for FLOAT32 columns. Applications should use the SQLAlchemy
type REAL to create a FLOAT32 column, as FLOAT is already reserved for
FLOAT64.

Fixes #409

* chore: run code formatter

* fix: remove DOUBLE reference which is SQLAlchemy 2.0-only
  • Loading branch information
olavloite authored Dec 9, 2024
1 parent dbb19c4 commit 6c3cb42
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 3 deletions.
11 changes: 11 additions & 0 deletions google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def reset_connection(dbapi_conn, connection_record, reset_state=None):
"BYTES": types.LargeBinary,
"DATE": types.DATE,
"DATETIME": types.DATETIME,
"FLOAT32": types.REAL,
"FLOAT64": types.Float,
"INT64": types.BIGINT,
"NUMERIC": types.NUMERIC(precision=38, scale=9),
Expand All @@ -101,6 +102,7 @@ def reset_connection(dbapi_conn, connection_record, reset_state=None):
types.LargeBinary: "BYTES(MAX)",
types.DATE: "DATE",
types.DATETIME: "DATETIME",
types.REAL: "FLOAT32",
types.Float: "FLOAT64",
types.BIGINT: "INT64",
types.DECIMAL: "NUMERIC",
Expand Down Expand Up @@ -540,9 +542,18 @@ class SpannerTypeCompiler(GenericTypeCompiler):
def visit_INTEGER(self, type_, **kw):
return "INT64"

def visit_DOUBLE(self, type_, **kw):
return "FLOAT64"

def visit_FLOAT(self, type_, **kw):
# Note: This was added before Spanner supported FLOAT32.
# Changing this now to generate a FLOAT32 would be a breaking change.
# Users therefore have to use REAL to generate a FLOAT32 column.
return "FLOAT64"

def visit_REAL(self, type_, **kw):
return "FLOAT32"

def visit_TEXT(self, type_, **kw):
return "STRING({})".format(type_.length or "MAX")

Expand Down
30 changes: 30 additions & 0 deletions test/mockserver_tests/float32_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.types import REAL


class Base(DeclarativeBase):
pass


class Number(Base):
__tablename__ = "numbers"
number: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(30))
ln: Mapped[float] = mapped_column(REAL)
73 changes: 73 additions & 0 deletions test/mockserver_tests/test_float32.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sqlalchemy.orm import Session
from sqlalchemy.testing import (
eq_,
is_instance_of,
is_false,
)
from google.cloud.spanner_v1 import (
BatchCreateSessionsRequest,
ExecuteSqlRequest,
ResultSet,
ResultSetStats,
BeginTransactionRequest,
CommitRequest,
TypeCode,
)
from test.mockserver_tests.mock_server_test_base import (
MockServerTestBase,
add_result,
)


class TestFloat32(MockServerTestBase):
def test_insert_data(self):
from test.mockserver_tests.float32_model import Number

update_count = ResultSet(
dict(
stats=ResultSetStats(
dict(
row_count_exact=1,
)
)
)
)
add_result(
"INSERT INTO numbers (number, name, ln) VALUES (@a0, @a1, @a2)",
update_count,
)

engine = self.create_engine()
with Session(engine) as session:
n1 = Number(number=1, name="One", ln=0.0)
session.add_all([n1])
session.commit()

requests = self.spanner_service.requests
eq_(4, len(requests))
is_instance_of(requests[0], BatchCreateSessionsRequest)
is_instance_of(requests[1], BeginTransactionRequest)
is_instance_of(requests[2], ExecuteSqlRequest)
is_instance_of(requests[3], CommitRequest)
request: ExecuteSqlRequest = requests[2]
eq_(3, len(request.params))
eq_("1", request.params["a0"])
eq_("One", request.params["a1"])
eq_(0.0, request.params["a2"])
eq_(TypeCode.INT64, request.param_types["a0"].code)
eq_(TypeCode.STRING, request.param_types["a1"].code)
is_false("a2" in request.param_types)
1 change: 0 additions & 1 deletion test/mockserver_tests/test_quickstart.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class TestQuickStart(MockServerTestBase):
def test_create_tables(self):
from test.mockserver_tests.quickstart_model import Base

# TODO: Fix the double quotes inside these SQL fragments.
add_result(
"""SELECT true
FROM INFORMATION_SCHEMA.TABLES
Expand Down
39 changes: 37 additions & 2 deletions test/system/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
Index,
MetaData,
Boolean,
BIGINT,
)
from sqlalchemy.orm import Session, DeclarativeBase, Mapped, mapped_column
from sqlalchemy.types import REAL
from sqlalchemy.testing import eq_
from sqlalchemy.testing.plugin.plugin_base import fixtures

Expand All @@ -37,6 +40,7 @@ def define_tables(cls, metadata):
Column("name", String(20)),
Column("alternative_name", String(20)),
Column("prime", Boolean),
Column("ln", REAL),
PrimaryKeyConstraint("number"),
)
Index(
Expand All @@ -53,8 +57,8 @@ def test_hello_world(self, connection):
def test_insert_number(self, connection):
connection.execute(
text(
"""insert or update into numbers (number, name, prime)
values (1, 'One', false)"""
"""insert or update into numbers (number, name, prime, ln)
values (1, 'One', false, cast(ln(1) as float32))"""
)
)
name = connection.execute(text("select name from numbers where number=1"))
Expand All @@ -66,6 +70,17 @@ def test_reflect(self, connection):
meta.reflect(bind=engine)
eq_(1, len(meta.tables))
table = meta.tables["numbers"]
eq_(5, len(table.columns))
eq_("number", table.columns[0].name)
eq_(BIGINT, type(table.columns[0].type))
eq_("name", table.columns[1].name)
eq_(String, type(table.columns[1].type))
eq_("alternative_name", table.columns[2].name)
eq_(String, type(table.columns[2].type))
eq_("prime", table.columns[3].name)
eq_(Boolean, type(table.columns[3].type))
eq_("ln", table.columns[4].name)
eq_(REAL, type(table.columns[4].type))
eq_(1, len(table.indexes))
index = next(iter(table.indexes))
eq_(2, len(index.columns))
Expand All @@ -74,3 +89,23 @@ def test_reflect(self, connection):
dialect_options = index.dialect_options["spanner"]
eq_(1, len(dialect_options["storing"]))
eq_("alternative_name", dialect_options["storing"][0])

def test_orm(self, connection):
class Base(DeclarativeBase):
pass

class Number(Base):
__tablename__ = "numbers"
number: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(20))
alternative_name: Mapped[str] = mapped_column(String(20))
prime: Mapped[bool] = mapped_column(Boolean)
ln: Mapped[float] = mapped_column(REAL)

engine = connection.engine
with Session(engine) as session:
number = Number(
number=1, name="One", alternative_name="Uno", prime=False, ln=0.0
)
session.add(number)
session.commit()

0 comments on commit 6c3cb42

Please sign in to comment.