From 0cd039a3483a4785f15cb7396a752212536fb631 Mon Sep 17 00:00:00 2001 From: hantmac Date: Mon, 17 Jun 2024 16:22:22 +0800 Subject: [PATCH] fix: casting 'False' to boolean returns True --- databend_py/datetypes.py | 8 +++++++- tests/test_client.py | 6 ++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/databend_py/datetypes.py b/databend_py/datetypes.py index 8e66c1a..3be65af 100644 --- a/databend_py/datetypes.py +++ b/databend_py/datetypes.py @@ -23,7 +23,7 @@ def type_convert_fn(type_str: str): elif DOUBLETYPE in type_str.lower(): return float elif BOOLEANTYPE in type_str.lower(): - return bool + return str_to_bool elif MAPTYPE in type_str.lower(): return ast.literal_eval elif ARRAYTYPE in type_str.lower(): @@ -34,6 +34,12 @@ def type_convert_fn(type_str: str): return str +def str_to_bool(s): + if isinstance(s, str) and s.isdigit(): + return bool(int(s)) + return bool(s) + + if __name__ == '__main__': d = DatabendDataType() print(d.type_convert_fn("Uint64")('0')) diff --git a/tests/test_client.py b/tests/test_client.py index 267ba05..da3a3af 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -252,6 +252,11 @@ def test_rollback(self): _, data = client.execute("select * from test_rollback") self.assertEqual(data, []) + def test_cast_bool(self): + client = Client.from_url(self.databend_url) + _, data = client.execute("select 'False'::boolean union select 'True'::boolean") + self.assertEqual(data, [(True,), (False,)]) + if __name__ == '__main__': print("start test......") @@ -269,4 +274,5 @@ def test_rollback(self): dt.test_cookies() dt.test_null_to_none() dt.tearDown() + dt.test_cast_bool() print("end test.....")