From aa7301adfc6928beff1b46a8ba172071b1c7e398 Mon Sep 17 00:00:00 2001 From: Liyuan Chen Date: Tue, 17 May 2022 13:29:31 +0100 Subject: [PATCH 01/14] implement where api --- exetera/core/abstract_types.py | 4 + exetera/core/fields.py | 42 ++++++ tests/test_fields.py | 267 +++++++++++++++++++++++++++++++++ 3 files changed, 313 insertions(+) diff --git a/exetera/core/abstract_types.py b/exetera/core/abstract_types.py index 48bede9..6b725f9 100644 --- a/exetera/core/abstract_types.py +++ b/exetera/core/abstract_types.py @@ -85,6 +85,10 @@ def isin(self, test_elements:Union[list, set, np.ndarray]): def unique(self, return_index=False, return_inverse=False, return_counts=False): raise NotImplementedError() + @staticmethod + def where(cond, a, b): + raise NotImplementedError() + class Dataset(ABC): """ diff --git a/exetera/core/fields.py b/exetera/core/fields.py index 1cb0e42..f7c6c18 100644 --- a/exetera/core/fields.py +++ b/exetera/core/fields.py @@ -39,6 +39,23 @@ def isin(field:Field, test_elements:Union[list, set, np.ndarray]): return ret +def where(cond, a, b): + if isinstance(cond, list) or (isinstance(cond, np.ndarray) and cond.dtype == 'bool'): + cond = cond + elif isinstance(cond, Field): + if cond.indexed: + raise NotImplementedError("Where does not support indexed string fields at present") + cond = cond.data[:] + elif callable(cond): + raise NotImplementedError("fields.where doesn't support callable cond") + + if isinstance(a, Field): + a = a.data[:] + if isinstance(b, Field): + b = b.data[:] + return np.where(cond, a, b) + + class HDF5Field(Field): def __init__(self, session, group, dataframe, write_enabled=False): super().__init__() @@ -143,6 +160,31 @@ def _ensure_valid(self): if not self._valid_reference: raise ValueError("This field no longer refers to a valid underlying field object") + def where(self, cond, b, inplace=False): + + if callable(cond): + cond = cond(self.data[:]) + elif isinstance(cond, list) or (isinstance(cond, np.ndarray) and cond.dtype == 'bool'): + cond = cond + elif isinstance(cond, Field): + if cond.indexed: + raise NotImplementedError("Where does not support indexed string fields at present") + cond = cond.data[:] + else: + raise TypeError("'cond' parameter needs to be either callable lambda function, or boolean ndarray, or NumericMemField") + + if isinstance(b, str): + b = b.encode() + if isinstance(b, Field): + b = b.data[:] + + result = np.where(cond, self.data[:], b) + + if inplace: + self.data.clear() + self.data.write(result) + return result + class MemoryField(Field): diff --git a/tests/test_fields.py b/tests/test_fields.py index d061233..b45d420 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -2169,6 +2169,273 @@ def test_indexed_string_isin(self, data, isin_data, expected): np.testing.assert_array_equal(expected, result) +MODULE_WHERE_TESTS = [ + (lambda f: f > 5, [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', 0, None), + ( + [False, False, False, False, True, False, True, True, False, True, False], + [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', + 9, None + ), + ( + [False, False, False, False, True, False, True, True, False, True, False], + [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'int64', + ), +] + +INSTANCE_WHERE_NUMERIC_TESTS = [ + (lambda f: f > 5, [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', 0, None), + (lambda f: f > 5, [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', 9, None), + (lambda f: f > 5, [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', -1, None), + (lambda f: f > 5, [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'int32'), + (lambda f: f > 5, [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9], 'int64'), + (lambda f: f > 5, [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0], 'float32'), + # ( + # [False, False, False, False, True, False, True, True, False, True, False], + # [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], + + # ) + +] + +def where_oracle(cond, a, b): + if callable(cond): + if isinstance(a, fields.Field): + cond = cond(a.data[:]) + elif isinstance(a, list): + cond = cond(np.array(a)) + elif isinstance(a, np.ndarray): + cond = cond(a) + return np.where(cond, a, b) + + +class TestFieldWhereFunctions(SessionTestCase): + + @parameterized.expand(MODULE_WHERE_TESTS) + def test_module_field_where(self, cond, a_field_data, a_field_dtype, b_data, b_dtype): + """ + Test `where` for the numeric fields using `fields.where` function and the object's method. + """ + a_field = self.setup_field(self.df, "create_numeric", "af", (a_field_dtype,), {}, a_field_data) + + + expected_result = where_oracle(cond, a_field_data, b_data) + + with self.subTest("Test module function: numeric field and single numeric value"): + if callable(cond): + with self.assertRaises(NotImplementedError) as context: + result = fields.where(cond, a_field, b_data) + self.assertEqual(str(context.exception), "fields.where doesn't support callable cond") + else: + result = fields.where(cond, a_field, b_data) + np.testing.assert_array_equal(expected_result, result) + + if b_dtype is not None: + b_field = self.setup_field(self.df, "create_numeric", "bf", (b_dtype,), {}, b_data) + + with self.subTest("Test module function: numeric field and single numeric value"): + if callable(cond): + with self.assertRaises(NotImplementedError) as context: + result = fields.where(cond, a_field, b_field) + self.assertEqual(str(context.exception), "fields.where doesn't support callable cond") + else: + result = fields.where(cond, a_field, b_field) + np.testing.assert_array_equal(expected_result, result) + + + @parameterized.expand(INSTANCE_WHERE_NUMERIC_TESTS) + def test_instance_field_where(self, cond, a_field_data, a_field_dtype, b_data, b_dtype): + a_field = self.setup_field(self.df, "create_numeric", "af", (a_field_dtype,), {}, a_field_data) + + expected_result = where_oracle(cond, a_field_data, b_data) + + with self.subTest("Test field method: numeric field and single numeric value"): + result = a_field.where(cond, b_data) + # self.assertEqual(result.dtype, ) + np.testing.assert_array_equal(expected_result, result) + + + +# (1) data type of return result, if it's + + # def _test_module_where(self, create_field_fn, predicate, if_true, if_false, expected): + # bio = BytesIO() + # with session.Session() as s: + # src = s.open_dataset(bio, 'w', 'src') + # df = src.create_dataframe('df') + # f = create_field_fn(df, 'foo') + + # with self.subTest("testing module where"): + # r = fields.where(predicate(f), if_true, if_false) + # self.assertEqual(r.tolist(), expected) + + # def _test_instance_where(self, create_field_fn, predicate, if_false, expected): + # bio = BytesIO() + # with session.Session() as s: + # src = s.open_dataset(bio, 'w', 'src') + # df = src.create_dataframe('df') + # f = create_field_fn(df, 'foo') + + # with self.subTest("testing field where"): + # r = f.where(predicate(f), if_false) + # self.assertEqual(r.tolist(), expected) + # with self.subTest("testing field where with predicate"): + # r = f.where(lambda f2: predicate(f2), if_false) + # self.assertEqual(r.tolist(), expected) + # with self.subTest("testing inplace field where"): + # r = f.where(predicate(f), if_false, inplace=True) + # self.assertEqual(list(f.data[:]), expected) + + # def test_field_where_numeric_int32(self): + # def create_numeric(df, name): + # f = df.create_numeric(name, 'int32') + # f.data.write(np.asarray([-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], dtype=np.int32)) + # return f + + # self._test_module_where(create_numeric, lambda f: f > 5, 1, 0, + # [0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0]) + # self._test_module_where(create_numeric, lambda f: f > 5, 1, 9, + # [9, 9, 9, 9, 1, 9, 1, 1, 9, 1, 9]) + # self._test_module_where(create_numeric, lambda f: f > 5, 1, -1, + # [-1, -1, -1, -1, 1, -1, 1, 1, -1, 1, -1]) + + # self._test_instance_where(create_numeric, lambda f: f > 5, 0, + # [0, 0, 0, 0, 9, 0, 8, 6, 0, 7, 0]) + # self._test_instance_where(create_numeric, lambda f: f > 5, 9, + # [9, 9, 9, 9, 9, 9, 8, 6, 9, 7, 9]) + # self._test_instance_where(create_numeric, lambda f: f > 5, -1, + # [-1, -1, -1, -1, 9, -1, 8, 6, -1, 7, -1]) + + # def test_field_where_numeric_float32(self): + # def create_numeric(df, name): + # f = df.create_numeric(name, 'float32') + # f.data.write(np.asarray([1e-7, 0.24, 1873.0, -0.0088, 227819.38457, np.nan, 0.0], + # dtype=np.float32)) + # return f + + # self._test_module_where(create_numeric, lambda f: np.isnan(f.data[:]), 1, 0, + # [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]) + # self._test_module_where(create_numeric, lambda f: f > 5, 1, 9, + # [9, 9, 9, 9, 1, 9, 1, 1, 9, 1, 9]) + # self._test_module_where(create_numeric, lambda f: f > 5, 1, -1, + # [-1, -1, -1, -1, 1, -1, 1, 1, -1, 1, -1]) + # + # self._test_instance_where(create_numeric, lambda f: f > 5, 0, + # [0, 0, 0, 0, 9, 0, 8, 6, 0, 7, 0]) + # self._test_instance_where(create_numeric, lambda f: f > 5, 9, + # [9, 9, 9, 9, 9, 9, 8, 6, 9, 7, 9]) + # self._test_instance_where(create_numeric, lambda f: f > 5, -1, + # [-1, -1, -1, -1, 9, -1, 8, 6, -1, 7, -1]) + + # def test_field_where_categorical(self): + # def create_categorical(df, name): + # f = df.create_categorical(name, nformat='int8', key={'x': 0, 'y': 1, 'xy': 2}) + # f.data.write(np.asarray([0, 2, 1, 2, 0, 1, 1, 2, 1])) + # return f + + # self._test_module_where(create_categorical, lambda f: f == 2, 1, 0, + # [0, 1, 0, 1, 0, 0, 0, 1, 0]) + + # self._test_instance_where(create_categorical, lambda f: f != 2, -1, + # [0, -1, 1, -1, 0, 1, 1, -1, 1]) + + # def test_field_where_fixed_string(self): + # def create_fixed_string(df, name): + # f = df.create_fixed_string(name, 6) + # f.data.write(np.asarray(['foo', '"foo"', '', 'bar', 'barn', 'bat'], dtype='S6')) + # return f + + # self._test_module_where(create_fixed_string, lambda f: np.char.str_len(f.data[:]) > 3, + # 'boo', '_far', + # ['_far', 'boo', '_far', '_far', 'boo', '_far']) + # # [b'_far', b'boo', b'_far', b'_far', b'boo', b'_far']) + + # self._test_instance_where(create_fixed_string, lambda f: np.char.str_len(f.data[:]) > 3, + # 'foobar', + # [b'foobar', b'"foo"', b'foobar', b'foobar', b'barn', b'foobar']) + + + # def test_field_where_indexed_string(self): + # def create_indexed_string(df, name): + # f = df.create_indexed_string(name) + # f.data.write(['foo', '"foo"', '', 'bar', 'barn', 'bat']) + # return f + + # self._test_module_where(create_indexed_string, lambda f: np.char.str_len(f.data[:]) > 3, + # 'boo', '_far', ['_far', 'boo', '_far', '_far', 'boo', '_far']) + + # self._test_module_where(create_indexed_string, lambda f: (f.indices[1:] - f.indices[:-1]) > 3, + # 'boo', '_far', ['_far', 'boo', '_far', '_far', 'boo', '_far']) + + + # def test_module_where_numeric(self): + # input_data = [1, 2, 3, 5, 9, 8, 6, 4, 7, 0] + # data = np.asarray(input_data, dtype=np.int32) + # bio = BytesIO() + # with session.Session() as s: + # src = s.open_dataset(bio, 'w', 'src') + # df = src.create_dataframe('df') + # f = df.create_numeric('foo', 'int32') + # f.data.write(data) + # + # r = fields.where(f > 5, 1, 0) + # self.assertEqual(r.tolist(), [0,0,0,0,1,1,1,0,1,0]) + # + # def test_instance_where_numeric(self): + # input_data = [1,2,3,5,9,8,6,4,7,0] + # data = np.asarray(input_data, dtype=np.int32) + # bio = BytesIO() + # with session.Session() as s: + # src = s.open_dataset(bio, 'w', 'src') + # df = src.create_dataframe('df') + # f = df.create_numeric('foo', 'int32') + # f.data.write(data) + # r = f.where(f > 5, 0) + # self.assertEqual(r.tolist(), [0,0,0,0,9,8,6,0,7,0]) + # + # def test_instance_where_numeric_inplace(self): + # input_data = [1,2,3,5,9,8,6,4,7,0] + # data = np.asarray(input_data, dtype=np.int32) + # bio = BytesIO() + # with session.Session() as s: + # src = s.open_dataset(bio, 'w', 'src') + # df = src.create_dataframe('df') + # f = df.create_numeric('foo', 'int32') + # f.data.write(data) + # + # r = f.where(f > 5, 0) + # self.assertEqual(list(f.data[:]), [1,2,3,5,9,8,6,4,7,0]) + # r = f.where(f > 5, 0, inplace=True) + # self.assertEqual(list(f.data[:]), [0,0,0,0,9,8,6,0,7,0]) + # + # def test_instance_where_with_callable(self): + # input_data = [1,2,3,5,9,8,6,4,7,0] + # data = np.asarray(input_data, dtype=np.int32) + # bio = BytesIO() + # with session.Session() as s: + # src = s.open_dataset(bio, 'w', 'src') + # df = src.create_dataframe('df') + # f = df.create_numeric('foo', 'int32') + # f.data.write(data) + # + # r = f.where(lambda x: x > 5, 0) + # self.assertEqual(r.tolist(), [0,0,0,0,9,8,6,0,7,0]) + + # def test_where_bool_condition(self): + # input_data = [1,2,3,5,9,8,6,4,7,0] + # data = np.asarray(input_data, dtype=np.int32) + # bio = BytesIO() + # with session.Session() as s: + # src = s.open_dataset(bio, 'w', 'src') + # df = src.create_dataframe('df') + # f = df.create_numeric('foo', 'int32') + # f.data.write(data) + + # cond = np.array([False,False,True,True, False, True, True, False, False, True]) + # r = f.where(cond, 0) + # self.assertEqual(r.tolist(), [0,0,3,5,0,8,6,0,0,0]) + + class TestFieldModuleFunctions(SessionTestCase): @parameterized.expand(DEFAULT_FIELD_DATA) From 834b5a99c06f0ae8a846b7755a08a16cb7ff200e Mon Sep 17 00:00:00 2001 From: Liyuan Chen Date: Fri, 20 May 2022 18:54:52 +0100 Subject: [PATCH 02/14] implement where return memfield --- exetera/core/fields.py | 45 +++++--- tests/test_fields.py | 230 +++++++++-------------------------------- 2 files changed, 78 insertions(+), 197 deletions(-) diff --git a/exetera/core/fields.py b/exetera/core/fields.py index f7c6c18..4cd0ff6 100644 --- a/exetera/core/fields.py +++ b/exetera/core/fields.py @@ -20,6 +20,7 @@ from exetera.core.data_writer import DataWriter from exetera.core import operations as ops from exetera.core import validation as val +from exetera.core import utils def isin(field:Field, test_elements:Union[list, set, np.ndarray]): @@ -39,15 +40,15 @@ def isin(field:Field, test_elements:Union[list, set, np.ndarray]): return ret -def where(cond, a, b): - if isinstance(cond, list) or (isinstance(cond, np.ndarray) and cond.dtype == 'bool'): +def where(cond: Union[list, tuple, np.ndarray, Field], a, b): + if isinstance(cond, (list, tuple, np.ndarray)): cond = cond elif isinstance(cond, Field): if cond.indexed: - raise NotImplementedError("Where does not support indexed string fields at present") + raise NotImplementedError("Where does not support condition on indexed string fields at present") cond = cond.data[:] elif callable(cond): - raise NotImplementedError("fields.where doesn't support callable cond") + raise NotImplementedError("module method `fields.where` doesn't support callable cond, please use instance mehthod `where` for callable cond.") if isinstance(a, Field): a = a.data[:] @@ -160,30 +161,40 @@ def _ensure_valid(self): if not self._valid_reference: raise ValueError("This field no longer refers to a valid underlying field object") - def where(self, cond, b, inplace=False): - - if callable(cond): - cond = cond(self.data[:]) - elif isinstance(cond, list) or (isinstance(cond, np.ndarray) and cond.dtype == 'bool'): + def where(self, cond:Union[list, tuple, np.ndarray, Field], b, inplace=False): + if isinstance(cond, (list, tuple, np.ndarray)): cond = cond elif isinstance(cond, Field): if cond.indexed: raise NotImplementedError("Where does not support indexed string fields at present") cond = cond.data[:] + elif callable(cond): + cond = cond(self.data[:]) else: - raise TypeError("'cond' parameter needs to be either callable lambda function, or boolean ndarray, or NumericMemField") + raise TypeError("'cond' parameter needs to be either callable lambda function, or array like, or NumericMemField") - if isinstance(b, str): - b = b.encode() + # if isinstance(b, str): + # b = b.encode() if isinstance(b, Field): b = b.data[:] - result = np.where(cond, self.data[:], b) + result_ndarray = np.where(cond, self.data[:], b) + result_mem_field = None + if str(result_ndarray.dtype) in utils.PERMITTED_NUMERIC_TYPES: + result_mem_field = NumericMemField(self._session, str(result_ndarray.dtype)) + result_mem_field.data.write(result_ndarray) - if inplace: - self.data.clear() - self.data.write(result) - return result + elif isinstance(self, (IndexedStringField, FixedStringField)) or isinstance(b, (IndexedStringField, FixedStringField)): + result_mem_field = IndexedStringMemField(self._session) + result_mem_field.data.write(result_ndarray) + else: + raise NotImplementedError(f"instance method where doesn't support the current input type") + + # if inplace: + # self.data.clear() + # self.data.write(result) + + return result_mem_field class MemoryField(Field): diff --git a/tests/test_fields.py b/tests/test_fields.py index b45d420..33de873 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -2169,33 +2169,25 @@ def test_indexed_string_isin(self, data, isin_data, expected): np.testing.assert_array_equal(expected, result) -MODULE_WHERE_TESTS = [ - (lambda f: f > 5, [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', 0, None), - ( - [False, False, False, False, True, False, True, True, False, True, False], - [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', - 9, None - ), - ( - [False, False, False, False, True, False, True, True, False, True, False], - [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'int64', - ), -] +WHERE_NUMERIC_TESTS = [ + (lambda f: f > 5, "create_numeric", {"nformat": "int8"}, shuffle_randstate(list(range(-10,10))), None, None, 0, 'int8'), + (lambda f: f > 5, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), None, None, -1.0, 'float64'), + (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))), None, None, -1.0, 'float32'), + (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, shuffle_randstate(list(range(-10,10))), None, None, shuffle_randstate(list(range(0,20))), 'int64'), + (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, shuffle_randstate(list(range(-10,10))), None, None, np.array(shuffle_randstate(list(range(0,20))), dtype='int32'), 'int32'), + (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), None, None, np.array(shuffle_randstate(list(range(-20,0))), dtype='float32'), 'float32'), + (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))), "create_categorical", {"nformat": "int8", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), 'float32'), + (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))), "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), 'float64'), + (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))), 'float32'), + (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))), 'float32'), + (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))),"create_numeric", {"nformat": "float64"}, shuffle_randstate(list(range(-10,10))), 'float64'), + (RAND_STATE.randint(0, 2, 20).tolist(), "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(),"create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), 'int32'), -INSTANCE_WHERE_NUMERIC_TESTS = [ - (lambda f: f > 5, [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', 0, None), - (lambda f: f > 5, [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', 9, None), - (lambda f: f > 5, [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', -1, None), - (lambda f: f > 5, [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'int32'), - (lambda f: f > 5, [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9], 'int64'), - (lambda f: f > 5, [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], 'int32', [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0], 'float32'), - # ( - # [False, False, False, False, True, False, True, True, False, True, False], - # [-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], +] - # ) +WHERE_INDEXED_STRING_TESTS = [ + (lambda f: f > 5, ['a', 'b', 'c'], [1,2,3]), ] def where_oracle(cond, a, b): @@ -2211,133 +2203,62 @@ def where_oracle(cond, a, b): class TestFieldWhereFunctions(SessionTestCase): - @parameterized.expand(MODULE_WHERE_TESTS) - def test_module_field_where(self, cond, a_field_data, a_field_dtype, b_data, b_dtype): + @parameterized.expand(WHERE_NUMERIC_TESTS) + def test_module_fields_where(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_data, expected_dtype): """ Test `where` for the numeric fields using `fields.where` function and the object's method. """ - a_field = self.setup_field(self.df, "create_numeric", "af", (a_field_dtype,), {}, a_field_data) - - + a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data) expected_result = where_oracle(cond, a_field_data, b_data) - with self.subTest("Test module function: numeric field and single numeric value"): - if callable(cond): - with self.assertRaises(NotImplementedError) as context: + if b_kwarg is None: + with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_data)}"): + if callable(cond): + with self.assertRaises(NotImplementedError) as context: + result = fields.where(cond, a_field, b_data) + self.assertEqual(str(context.exception), "module method `fields.where` doesn't support callable cond, please use instance mehthod `where` for callable cond.") + else: result = fields.where(cond, a_field, b_data) - self.assertEqual(str(context.exception), "fields.where doesn't support callable cond") - else: - result = fields.where(cond, a_field, b_data) - np.testing.assert_array_equal(expected_result, result) - - if b_dtype is not None: - b_field = self.setup_field(self.df, "create_numeric", "bf", (b_dtype,), {}, b_data) + np.testing.assert_array_equal(expected_result, result) - with self.subTest("Test module function: numeric field and single numeric value"): + else: + b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_data) + with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"): if callable(cond): with self.assertRaises(NotImplementedError) as context: result = fields.where(cond, a_field, b_field) - self.assertEqual(str(context.exception), "fields.where doesn't support callable cond") + self.assertEqual(str(context.exception), "module method `fields.where` doesn't support callable cond, please use instance mehthod `where` for callable cond.") else: result = fields.where(cond, a_field, b_field) np.testing.assert_array_equal(expected_result, result) - @parameterized.expand(INSTANCE_WHERE_NUMERIC_TESTS) - def test_instance_field_where(self, cond, a_field_data, a_field_dtype, b_data, b_dtype): - a_field = self.setup_field(self.df, "create_numeric", "af", (a_field_dtype,), {}, a_field_data) + @parameterized.expand(WHERE_NUMERIC_TESTS) + def test_instance_field_where_return_numericmemfield(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_data, expected_dtype): + a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data) expected_result = where_oracle(cond, a_field_data, b_data) - with self.subTest("Test field method: numeric field and single numeric value"): - result = a_field.where(cond, b_data) - # self.assertEqual(result.dtype, ) - np.testing.assert_array_equal(expected_result, result) - - + if b_kwarg is None: + with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_data)}"): + result = a_field.where(cond, b_data) + self.assertEqual(result._nformat, expected_dtype) + np.testing.assert_array_equal(result, expected_result) -# (1) data type of return result, if it's - - # def _test_module_where(self, create_field_fn, predicate, if_true, if_false, expected): - # bio = BytesIO() - # with session.Session() as s: - # src = s.open_dataset(bio, 'w', 'src') - # df = src.create_dataframe('df') - # f = create_field_fn(df, 'foo') - - # with self.subTest("testing module where"): - # r = fields.where(predicate(f), if_true, if_false) - # self.assertEqual(r.tolist(), expected) - - # def _test_instance_where(self, create_field_fn, predicate, if_false, expected): - # bio = BytesIO() - # with session.Session() as s: - # src = s.open_dataset(bio, 'w', 'src') - # df = src.create_dataframe('df') - # f = create_field_fn(df, 'foo') - - # with self.subTest("testing field where"): - # r = f.where(predicate(f), if_false) - # self.assertEqual(r.tolist(), expected) - # with self.subTest("testing field where with predicate"): - # r = f.where(lambda f2: predicate(f2), if_false) - # self.assertEqual(r.tolist(), expected) - # with self.subTest("testing inplace field where"): - # r = f.where(predicate(f), if_false, inplace=True) - # self.assertEqual(list(f.data[:]), expected) - - # def test_field_where_numeric_int32(self): - # def create_numeric(df, name): - # f = df.create_numeric(name, 'int32') - # f.data.write(np.asarray([-1, 2, 3, 5, 9, 5, 8, 6, 4, 7, 0], dtype=np.int32)) - # return f - - # self._test_module_where(create_numeric, lambda f: f > 5, 1, 0, - # [0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0]) - # self._test_module_where(create_numeric, lambda f: f > 5, 1, 9, - # [9, 9, 9, 9, 1, 9, 1, 1, 9, 1, 9]) - # self._test_module_where(create_numeric, lambda f: f > 5, 1, -1, - # [-1, -1, -1, -1, 1, -1, 1, 1, -1, 1, -1]) - - # self._test_instance_where(create_numeric, lambda f: f > 5, 0, - # [0, 0, 0, 0, 9, 0, 8, 6, 0, 7, 0]) - # self._test_instance_where(create_numeric, lambda f: f > 5, 9, - # [9, 9, 9, 9, 9, 9, 8, 6, 9, 7, 9]) - # self._test_instance_where(create_numeric, lambda f: f > 5, -1, - # [-1, -1, -1, -1, 9, -1, 8, 6, -1, 7, -1]) - - # def test_field_where_numeric_float32(self): - # def create_numeric(df, name): - # f = df.create_numeric(name, 'float32') - # f.data.write(np.asarray([1e-7, 0.24, 1873.0, -0.0088, 227819.38457, np.nan, 0.0], - # dtype=np.float32)) - # return f + else: + b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_data) - # self._test_module_where(create_numeric, lambda f: np.isnan(f.data[:]), 1, 0, - # [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]) - # self._test_module_where(create_numeric, lambda f: f > 5, 1, 9, - # [9, 9, 9, 9, 1, 9, 1, 1, 9, 1, 9]) - # self._test_module_where(create_numeric, lambda f: f > 5, 1, -1, - # [-1, -1, -1, -1, 1, -1, 1, 1, -1, 1, -1]) - # - # self._test_instance_where(create_numeric, lambda f: f > 5, 0, - # [0, 0, 0, 0, 9, 0, 8, 6, 0, 7, 0]) - # self._test_instance_where(create_numeric, lambda f: f > 5, 9, - # [9, 9, 9, 9, 9, 9, 8, 6, 9, 7, 9]) - # self._test_instance_where(create_numeric, lambda f: f > 5, -1, - # [-1, -1, -1, -1, 9, -1, 8, 6, -1, 7, -1]) - - # def test_field_where_categorical(self): - # def create_categorical(df, name): - # f = df.create_categorical(name, nformat='int8', key={'x': 0, 'y': 1, 'xy': 2}) - # f.data.write(np.asarray([0, 2, 1, 2, 0, 1, 1, 2, 1])) - # return f + with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"): + result = a_field.where(cond, b_field) + self.assertIsInstance(result, fields.NumericMemField) + self.assertEqual(result._nformat, expected_dtype) + np.testing.assert_array_equal(result, expected_result) - # self._test_module_where(create_categorical, lambda f: f == 2, 1, 0, - # [0, 1, 0, 1, 0, 0, 0, 1, 0]) - # self._test_instance_where(create_categorical, lambda f: f != 2, -1, - # [0, -1, 1, -1, 0, 1, 1, -1, 1]) + @parameterized.expand(WHERE_INDEXED_STRING_TESTS) + def test_instance_field_where_return_numericmemfield(self, cond, a, b): + pass + # def test_field_where_fixed_string(self): # def create_fixed_string(df, name): @@ -2367,31 +2288,6 @@ def test_instance_field_where(self, cond, a_field_data, a_field_dtype, b_data, b # self._test_module_where(create_indexed_string, lambda f: (f.indices[1:] - f.indices[:-1]) > 3, # 'boo', '_far', ['_far', 'boo', '_far', '_far', 'boo', '_far']) - - # def test_module_where_numeric(self): - # input_data = [1, 2, 3, 5, 9, 8, 6, 4, 7, 0] - # data = np.asarray(input_data, dtype=np.int32) - # bio = BytesIO() - # with session.Session() as s: - # src = s.open_dataset(bio, 'w', 'src') - # df = src.create_dataframe('df') - # f = df.create_numeric('foo', 'int32') - # f.data.write(data) - # - # r = fields.where(f > 5, 1, 0) - # self.assertEqual(r.tolist(), [0,0,0,0,1,1,1,0,1,0]) - # - # def test_instance_where_numeric(self): - # input_data = [1,2,3,5,9,8,6,4,7,0] - # data = np.asarray(input_data, dtype=np.int32) - # bio = BytesIO() - # with session.Session() as s: - # src = s.open_dataset(bio, 'w', 'src') - # df = src.create_dataframe('df') - # f = df.create_numeric('foo', 'int32') - # f.data.write(data) - # r = f.where(f > 5, 0) - # self.assertEqual(r.tolist(), [0,0,0,0,9,8,6,0,7,0]) # # def test_instance_where_numeric_inplace(self): # input_data = [1,2,3,5,9,8,6,4,7,0] @@ -2408,32 +2304,6 @@ def test_instance_field_where(self, cond, a_field_data, a_field_dtype, b_data, b # r = f.where(f > 5, 0, inplace=True) # self.assertEqual(list(f.data[:]), [0,0,0,0,9,8,6,0,7,0]) # - # def test_instance_where_with_callable(self): - # input_data = [1,2,3,5,9,8,6,4,7,0] - # data = np.asarray(input_data, dtype=np.int32) - # bio = BytesIO() - # with session.Session() as s: - # src = s.open_dataset(bio, 'w', 'src') - # df = src.create_dataframe('df') - # f = df.create_numeric('foo', 'int32') - # f.data.write(data) - # - # r = f.where(lambda x: x > 5, 0) - # self.assertEqual(r.tolist(), [0,0,0,0,9,8,6,0,7,0]) - - # def test_where_bool_condition(self): - # input_data = [1,2,3,5,9,8,6,4,7,0] - # data = np.asarray(input_data, dtype=np.int32) - # bio = BytesIO() - # with session.Session() as s: - # src = s.open_dataset(bio, 'w', 'src') - # df = src.create_dataframe('df') - # f = df.create_numeric('foo', 'int32') - # f.data.write(data) - - # cond = np.array([False,False,True,True, False, True, True, False, False, True]) - # r = f.where(cond, 0) - # self.assertEqual(r.tolist(), [0,0,3,5,0,8,6,0,0,0]) class TestFieldModuleFunctions(SessionTestCase): From fd79955188e0727d586b8b022734fef68d646e1e Mon Sep 17 00:00:00 2001 From: Liyuan Chen Date: Thu, 9 Jun 2022 19:14:13 +0100 Subject: [PATCH 03/14] implement fixed string field and add parameterized unittest --- exetera/core/fields.py | 44 +++++++++++++++++++++++++++++------------- tests/test_fields.py | 23 ++++++++++++++++++++-- 2 files changed, 52 insertions(+), 15 deletions(-) diff --git a/exetera/core/fields.py b/exetera/core/fields.py index 4cd0ff6..3c8ef95 100644 --- a/exetera/core/fields.py +++ b/exetera/core/fields.py @@ -15,6 +15,7 @@ import numpy as np import h5py +import re from exetera.core.abstract_types import Field from exetera.core.data_writer import DataWriter @@ -161,7 +162,7 @@ def _ensure_valid(self): if not self._valid_reference: raise ValueError("This field no longer refers to a valid underlying field object") - def where(self, cond:Union[list, tuple, np.ndarray, Field], b, inplace=False): + def where(self, cond:Union[list, tuple, np.ndarray, Field, Callable], b, inplace=False): if isinstance(cond, (list, tuple, np.ndarray)): cond = cond elif isinstance(cond, Field): @@ -173,22 +174,39 @@ def where(self, cond:Union[list, tuple, np.ndarray, Field], b, inplace=False): else: raise TypeError("'cond' parameter needs to be either callable lambda function, or array like, or NumericMemField") - # if isinstance(b, str): - # b = b.encode() - if isinstance(b, Field): - b = b.data[:] - result_ndarray = np.where(cond, self.data[:], b) result_mem_field = None - if str(result_ndarray.dtype) in utils.PERMITTED_NUMERIC_TYPES: - result_mem_field = NumericMemField(self._session, str(result_ndarray.dtype)) - result_mem_field.data.write(result_ndarray) - elif isinstance(self, (IndexedStringField, FixedStringField)) or isinstance(b, (IndexedStringField, FixedStringField)): - result_mem_field = IndexedStringMemField(self._session) - result_mem_field.data.write(result_ndarray) + if isinstance(self, IndexedStringField) or isinstance(b, IndexedStringField): + # TODO: return IndexedStringMemField + + # if isinstance(b, str): + # b = b.encode() + + pass else: - raise NotImplementedError(f"instance method where doesn't support the current input type") + b_data = b.data[:] if isinstance(b, Field) else b + + + result_ndarray = np.where(cond, self.data[:], b_data) + + + if isinstance(self, FixedStringField) or isinstance(b, FixedStringField): + length = 0 + result = re.findall(r" 5, "create_numeric", {"nformat": "int8"}, shuffle_randstate(list(range(-10,10))), "create_fixed_string", {"length": 3}, [b"aaa", b"bbb", b"eee", b"ccc", b" "]*4), + (lambda f: f > 2, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), "create_fixed_string", {"length": 3}, [b"aaa", b"bbb", b"eee", b"ccc", b" "]*4), + (RAND_STATE.randint(0, 2, 20).tolist(), "create_fixed_string", {"length": 3}, [b"aaa", b"bbb", b"eee", b"ccc", b" "]*4, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist()), + (RAND_STATE.randint(0, 2, 20).tolist(), "create_fixed_string", {"length": 3}, [b"aaa", b"bbb", b"eee", b"ccc", b" "]*4, "create_numeric", {"nformat": "int8"}, shuffle_randstate(list(range(-10,10)))), +] WHERE_INDEXED_STRING_TESTS = [ (lambda f: f > 5, ['a', 'b', 'c'], [1,2,3]), @@ -2234,7 +2240,7 @@ def test_module_fields_where(self, cond, a_creator, a_kwarg, a_field_data, b_cre @parameterized.expand(WHERE_NUMERIC_TESTS) - def test_instance_field_where_return_numericmemfield(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_data, expected_dtype): + def test_instance_field_where_return_numeric_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_data, expected_dtype): a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data) expected_result = where_oracle(cond, a_field_data, b_data) @@ -2255,8 +2261,21 @@ def test_instance_field_where_return_numericmemfield(self, cond, a_creator, a_kw np.testing.assert_array_equal(result, expected_result) + @parameterized.expand(WHERE_FIXED_STRING_TESTS) + def test_instance_field_where_return_fixed_string_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_data): + a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data) + b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_data) + + expected_result = where_oracle(cond, a_field_data, b_data) + + with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"): + result = a_field.where(cond, b_field) + self.assertIsInstance(result, fields.FixedStringMemField) + np.testing.assert_array_equal(result.data[:], expected_result) + + @parameterized.expand(WHERE_INDEXED_STRING_TESTS) - def test_instance_field_where_return_numericmemfield(self, cond, a, b): + def test_instance_field_where_return_indexed_string_mem_field(self, cond, a, b): pass From a5643059b3f856d65d47aaa67aa5d8c9b32e51f4 Mon Sep 17 00:00:00 2001 From: Liyuan Chen Date: Fri, 10 Jun 2022 18:07:53 +0100 Subject: [PATCH 04/14] implement where for two indexed string fields --- exetera/core/fields.py | 38 +++++++++++------ exetera/core/operations.py | 10 +++++ tests/test_fields.py | 83 +++++++++++++++++--------------------- 3 files changed, 71 insertions(+), 60 deletions(-) diff --git a/exetera/core/fields.py b/exetera/core/fields.py index 3c8ef95..1064e8d 100644 --- a/exetera/core/fields.py +++ b/exetera/core/fields.py @@ -177,34 +177,46 @@ def where(self, cond:Union[list, tuple, np.ndarray, Field, Callable], b, inplace result_mem_field = None - if isinstance(self, IndexedStringField) or isinstance(b, IndexedStringField): - # TODO: return IndexedStringMemField + if isinstance(self, IndexedStringField) and isinstance(b, IndexedStringField): - # if isinstance(b, str): - # b = b.encode() + a_indices, a_values = self.indices[:], self.values[:] + b_indices, b_values = b.indices[:], b.values[:] + if len(cond) != len(a_indices) - 1 or len(cond) != len(b_indices) - 1: + raise ValueError(f"operands can't work with shapes ({len(cond)},) ({len(a_indices) - 1},) ({len(b_indices) - 1},)") - pass - else: - b_data = b.data[:] if isinstance(b, Field) else b + r_indices = np.zeros(len(a_indices), dtype=np.int64) + r_values = np.zeros(max(len(a_values), len(b_values)), dtype=np.uint8) + + ops.where_for_two_indexed_string_fields(np.array(cond), a_indices, a_values, b_indices, b_values, r_indices, r_values) + r_values = r_values[:r_indices[-1]] - result_ndarray = np.where(cond, self.data[:], b_data) + result_mem_field = IndexedStringMemField(self._session) + result_mem_field.indices.write(r_indices) + result_mem_field.values.write(r_values) + elif isinstance(self, IndexedStringField) or isinstance(b, IndexedStringField): + # TODO: return IndexedStringMemField + # operands could not be broadcast together with shapes (4,) (3,) (3,) + pass + else: + b_data = b.data[:] if isinstance(b, Field) else b + r_ndarray = np.where(cond, self.data[:], b_data) if isinstance(self, FixedStringField) or isinstance(b, FixedStringField): length = 0 - result = re.findall(r" 5, "create_numeric", {"nformat": "int8"}, shuffle_randstate(list(range(-10,10))), None, None, 0, 'int8'), - (lambda f: f > 5, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), None, None, -1.0, 'float64'), - (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))), None, None, -1.0, 'float32'), - (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, shuffle_randstate(list(range(-10,10))), None, None, shuffle_randstate(list(range(0,20))), 'int64'), - (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, shuffle_randstate(list(range(-10,10))), None, None, np.array(shuffle_randstate(list(range(0,20))), dtype='int32'), 'int32'), - (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), None, None, np.array(shuffle_randstate(list(range(-20,0))), dtype='float32'), 'float32'), - (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))), "create_categorical", {"nformat": "int8", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), 'float32'), - (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))), "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), 'float64'), - (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))), 'float32'), - (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))), 'float32'), - (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, shuffle_randstate(list(range(-10,10))),"create_numeric", {"nformat": "float64"}, shuffle_randstate(list(range(-10,10))), 'float64'), - (RAND_STATE.randint(0, 2, 20).tolist(), "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(),"create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), 'int32'), + (lambda f: f > 5, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA, None, None, 0, 'int8'), + (lambda f: f > 5, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, None, None, -1.0, 'float64'), + (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, None, None, -1.0, 'float32'), + (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, WHERE_NUMERIC_FIELD_DATA, None, None, shuffle_randstate(list(range(0,20))), 'int64'), + (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, WHERE_NUMERIC_FIELD_DATA, None, None, np.array(shuffle_randstate(list(range(0,20))), dtype='int32'), 'int32'), + (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, None, None, np.array(shuffle_randstate(list(range(-20,0))), dtype='float32'), 'float32'), + (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, "create_categorical", {"nformat": "int8", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'float32'), + (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'float64'), + (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, 'float32'), + (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, 'float32'), + (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA,"create_numeric", {"nformat": "float64"}, WHERE_NUMERIC_FIELD_DATA, 'float64'), + (WHERE_BOOLEAN_COND, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA,"create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'int32'), ] WHERE_FIXED_STRING_TESTS = [ - (lambda f: f > 5, "create_numeric", {"nformat": "int8"}, shuffle_randstate(list(range(-10,10))), "create_fixed_string", {"length": 3}, [b"aaa", b"bbb", b"eee", b"ccc", b" "]*4), - (lambda f: f > 2, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist(), "create_fixed_string", {"length": 3}, [b"aaa", b"bbb", b"eee", b"ccc", b" "]*4), - (RAND_STATE.randint(0, 2, 20).tolist(), "create_fixed_string", {"length": 3}, [b"aaa", b"bbb", b"eee", b"ccc", b" "]*4, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, RAND_STATE.randint(1, 4, 20).tolist()), - (RAND_STATE.randint(0, 2, 20).tolist(), "create_fixed_string", {"length": 3}, [b"aaa", b"bbb", b"eee", b"ccc", b" "]*4, "create_numeric", {"nformat": "int8"}, shuffle_randstate(list(range(-10,10)))), + (lambda f: f > 5, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA), + (lambda f: f > 2, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA), + (WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA), + (WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA), + (WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA, "create_fixed_string", {"length": 1}, [b"a", b"b", b"e", b"c", b" "]*4 ), ] WHERE_INDEXED_STRING_TESTS = [ - (lambda f: f > 5, ['a', 'b', 'c'], [1,2,3]), + (WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA, "create_indexed_string", {}, np.array(shuffle_randstate(WHERE_INDEXED_STRING_FIELD_DATA))), ] def where_oracle(cond, a, b): @@ -2262,11 +2269,11 @@ def test_instance_field_where_return_numeric_mem_field(self, cond, a_creator, a_ @parameterized.expand(WHERE_FIXED_STRING_TESTS) - def test_instance_field_where_return_fixed_string_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_data): + def test_instance_field_where_return_fixed_string_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_field_data): a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data) - b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_data) + b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_field_data) - expected_result = where_oracle(cond, a_field_data, b_data) + expected_result = where_oracle(cond, a_field_data, b_field_data) with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"): result = a_field.where(cond, b_field) @@ -2275,37 +2282,19 @@ def test_instance_field_where_return_fixed_string_mem_field(self, cond, a_creato @parameterized.expand(WHERE_INDEXED_STRING_TESTS) - def test_instance_field_where_return_indexed_string_mem_field(self, cond, a, b): - pass - - - # def test_field_where_fixed_string(self): - # def create_fixed_string(df, name): - # f = df.create_fixed_string(name, 6) - # f.data.write(np.asarray(['foo', '"foo"', '', 'bar', 'barn', 'bat'], dtype='S6')) - # return f - - # self._test_module_where(create_fixed_string, lambda f: np.char.str_len(f.data[:]) > 3, - # 'boo', '_far', - # ['_far', 'boo', '_far', '_far', 'boo', '_far']) - # # [b'_far', b'boo', b'_far', b'_far', b'boo', b'_far']) - - # self._test_instance_where(create_fixed_string, lambda f: np.char.str_len(f.data[:]) > 3, - # 'foobar', - # [b'foobar', b'"foo"', b'foobar', b'foobar', b'barn', b'foobar']) + def test_instance_field_where_return_indexed_string_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_field_data): + a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data) + b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_field_data) + expected_result = where_oracle(cond, a_field_data, b_field_data) - # def test_field_where_indexed_string(self): - # def create_indexed_string(df, name): - # f = df.create_indexed_string(name) - # f.data.write(['foo', '"foo"', '', 'bar', 'barn', 'bat']) - # return f + with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"): + result = a_field.where(cond, b_field) + self.assertIsInstance(result, fields.IndexedStringMemField) + np.testing.assert_array_equal(result.data[:], expected_result) + - # self._test_module_where(create_indexed_string, lambda f: np.char.str_len(f.data[:]) > 3, - # 'boo', '_far', ['_far', 'boo', '_far', '_far', 'boo', '_far']) - # self._test_module_where(create_indexed_string, lambda f: (f.indices[1:] - f.indices[:-1]) > 3, - # 'boo', '_far', ['_far', 'boo', '_far', '_far', 'boo', '_far']) # # def test_instance_where_numeric_inplace(self): From fa62e3b32fe786d8233ee742e2d7b3f9c37753d2 Mon Sep 17 00:00:00 2001 From: Liyuan Chen Date: Mon, 13 Jun 2022 14:21:52 +0100 Subject: [PATCH 05/14] implement instance where when one field indexedstringfield, the other is not --- exetera/core/fields.py | 52 +++++++++++++++++++++++++++++++++++++----- tests/test_fields.py | 10 +++++--- 2 files changed, 53 insertions(+), 9 deletions(-) diff --git a/exetera/core/fields.py b/exetera/core/fields.py index 1064e8d..40f9589 100644 --- a/exetera/core/fields.py +++ b/exetera/core/fields.py @@ -174,11 +174,9 @@ def where(self, cond:Union[list, tuple, np.ndarray, Field, Callable], b, inplace else: raise TypeError("'cond' parameter needs to be either callable lambda function, or array like, or NumericMemField") - result_mem_field = None if isinstance(self, IndexedStringField) and isinstance(b, IndexedStringField): - a_indices, a_values = self.indices[:], self.values[:] b_indices, b_values = b.indices[:], b.values[:] if len(cond) != len(a_indices) - 1 or len(cond) != len(b_indices) - 1: @@ -196,9 +194,51 @@ def where(self, cond:Union[list, tuple, np.ndarray, Field, Callable], b, inplace result_mem_field.values.write(r_values) elif isinstance(self, IndexedStringField) or isinstance(b, IndexedStringField): - # TODO: return IndexedStringMemField - # operands could not be broadcast together with shapes (4,) (3,) (3,) - pass + indexed_str_field = self if isinstance(self, IndexedStringField) else b + other_field = b if isinstance(self, IndexedStringField) else self + + # check length + indexed_str_field_row_count = len(indexed_str_field.indices[:]) - 1 + other_field_row_count = len(other_field.data[:]) + if len(cond) != indexed_str_field_row_count or len(cond) != other_field_row_count: + raise ValueError(f"operands can't work with shapes ({len(cond)},) ({indexed_str_field_row_count},) ({other_field_row_count},)") + + # convert other field data to string array + data_converted_to_str = np.where([True]*other_field_row_count, other_field.data[:], [""]*other_field_row_count) + maxLength = 0 + re_match = re.findall(r" Date: Tue, 14 Jun 2022 15:25:44 +0100 Subject: [PATCH 06/14] add IndexStringMemField check; move where logic to global --- exetera/core/fields.py | 203 ++++++++++++++++++++++------------------- tests/test_fields.py | 46 ++++++---- 2 files changed, 136 insertions(+), 113 deletions(-) diff --git a/exetera/core/fields.py b/exetera/core/fields.py index 40f9589..61b916b 100644 --- a/exetera/core/fields.py +++ b/exetera/core/fields.py @@ -51,11 +51,97 @@ def where(cond: Union[list, tuple, np.ndarray, Field], a, b): elif callable(cond): raise NotImplementedError("module method `fields.where` doesn't support callable cond, please use instance mehthod `where` for callable cond.") - if isinstance(a, Field): - a = a.data[:] - if isinstance(b, Field): - b = b.data[:] - return np.where(cond, a, b) + return where_helper(cond, a, b) + + +def where_helper(cond:Union[list, tuple, np.ndarray, Field], a, b) -> Field: + result_mem_field = None + + if isinstance(a, (IndexedStringField, IndexedStringMemField)) and isinstance(b, (IndexedStringField, IndexedStringMemField)): + a_indices, a_values = a.indices[:], a.values[:] + b_indices, b_values = b.indices[:], b.values[:] + if len(cond) != len(a_indices) - 1 or len(cond) != len(b_indices) - 1: + raise ValueError(f"operands can't work with shapes ({len(cond)},) ({len(a_indices) - 1},) ({len(b_indices) - 1},)") + + r_indices = np.zeros(len(a_indices), dtype=np.int64) + r_values = np.zeros(max(len(a_values), len(b_values)), dtype=np.uint8) + + ops.where_for_two_indexed_string_fields(np.array(cond), a_indices, a_values, b_indices, b_values, r_indices, r_values) + + r_values = r_values[:r_indices[-1]] + + result_mem_field = IndexedStringMemField(a._session) + result_mem_field.indices.write(r_indices) + result_mem_field.values.write(r_values) + + elif isinstance(a, (IndexedStringField, IndexedStringMemField)) or isinstance(b, (IndexedStringField, IndexedStringMemField)): + indexed_str_field = a if isinstance(a, (IndexedStringField, IndexedStringMemField)) else b + other_field = b if isinstance(a, (IndexedStringField, IndexedStringMemField)) else a + + # check length + indexed_str_field_row_count = len(indexed_str_field.indices[:]) - 1 + other_field_row_count = len(other_field.data[:]) + if len(cond) != indexed_str_field_row_count or len(cond) != other_field_row_count: + raise ValueError(f"operands can't work with shapes ({len(cond)},) ({indexed_str_field_row_count},) ({other_field_row_count},)") + + # convert other field data to string array + data_converted_to_str = np.where([True]*other_field_row_count, other_field.data[:], [""]*other_field_row_count) + maxLength = 0 + re_match = re.findall(r" 5, 0) - # self.assertEqual(list(f.data[:]), [1,2,3,5,9,8,6,4,7,0]) - # r = f.where(f > 5, 0, inplace=True) - # self.assertEqual(list(f.data[:]), [0,0,0,0,9,8,6,0,7,0]) - # + # reload to test FixedStringMemField + a_mem_field, b_mem_field = a_field, b_field + if isinstance(a_field, fields.IndexedStringMemField): + a_mem_field = fields.IndexedStringMemField(self.s) + a_mem_field.data.write(a_field_data) + + if isinstance(b_field, fields.IndexedStringMemField): + b_mem_field = fields.IndexedStringMemField(self.s) + b_mem_field.data.write(b_field_data) + + with self.subTest(f"Test instance where method: a is {type(a_mem_field)}, b is {type(b_mem_field)}"): + result = a_mem_field.where(cond, b_mem_field) + self.assertIsInstance(result, fields.IndexedStringMemField) + np.testing.assert_array_equal(result.data[:], expected_result) class TestFieldModuleFunctions(SessionTestCase): From bd5151903469d037ee0c79a12fefe41b2b3cf078 Mon Sep 17 00:00:00 2001 From: Liyuan Chen Date: Wed, 15 Jun 2022 11:49:04 +0100 Subject: [PATCH 07/14] combine the logic of 'a&b is indexedstringfield' and 'a|b is indexedstringfield' --- exetera/core/fields.py | 59 +++++++++++++++--------------------------- 1 file changed, 21 insertions(+), 38 deletions(-) diff --git a/exetera/core/fields.py b/exetera/core/fields.py index 61b916b..bee797c 100644 --- a/exetera/core/fields.py +++ b/exetera/core/fields.py @@ -55,36 +55,10 @@ def where(cond: Union[list, tuple, np.ndarray, Field], a, b): def where_helper(cond:Union[list, tuple, np.ndarray, Field], a, b) -> Field: - result_mem_field = None - - if isinstance(a, (IndexedStringField, IndexedStringMemField)) and isinstance(b, (IndexedStringField, IndexedStringMemField)): - a_indices, a_values = a.indices[:], a.values[:] - b_indices, b_values = b.indices[:], b.values[:] - if len(cond) != len(a_indices) - 1 or len(cond) != len(b_indices) - 1: - raise ValueError(f"operands can't work with shapes ({len(cond)},) ({len(a_indices) - 1},) ({len(b_indices) - 1},)") - - r_indices = np.zeros(len(a_indices), dtype=np.int64) - r_values = np.zeros(max(len(a_values), len(b_values)), dtype=np.uint8) - - ops.where_for_two_indexed_string_fields(np.array(cond), a_indices, a_values, b_indices, b_values, r_indices, r_values) - - r_values = r_values[:r_indices[-1]] - - result_mem_field = IndexedStringMemField(a._session) - result_mem_field.indices.write(r_indices) - result_mem_field.values.write(r_values) - - elif isinstance(a, (IndexedStringField, IndexedStringMemField)) or isinstance(b, (IndexedStringField, IndexedStringMemField)): - indexed_str_field = a if isinstance(a, (IndexedStringField, IndexedStringMemField)) else b - other_field = b if isinstance(a, (IndexedStringField, IndexedStringMemField)) else a - - # check length - indexed_str_field_row_count = len(indexed_str_field.indices[:]) - 1 - other_field_row_count = len(other_field.data[:]) - if len(cond) != indexed_str_field_row_count or len(cond) != other_field_row_count: - raise ValueError(f"operands can't work with shapes ({len(cond)},) ({indexed_str_field_row_count},) ({other_field_row_count},)") + def get_indices_and_values_from_non_indexed_string_field(other_field): # convert other field data to string array + other_field_row_count = len(other_field.data[:]) data_converted_to_str = np.where([True]*other_field_row_count, other_field.data[:], [""]*other_field_row_count) maxLength = 0 re_match = re.findall(r" Field: raise ValueError("The return dtype of instance method `where` doesn't match ' Date: Fri, 17 Jun 2022 11:55:07 +0100 Subject: [PATCH 08/14] add unittest when cond is field; add >U --- exetera/core/fields.py | 43 ++++++++++++++++++++++++------------------ tests/test_fields.py | 21 +++++++++++++++++++++ 2 files changed, 46 insertions(+), 18 deletions(-) diff --git a/exetera/core/fields.py b/exetera/core/fields.py index bee797c..6854ade 100644 --- a/exetera/core/fields.py +++ b/exetera/core/fields.py @@ -45,9 +45,10 @@ def where(cond: Union[list, tuple, np.ndarray, Field], a, b): if isinstance(cond, (list, tuple, np.ndarray)): cond = cond elif isinstance(cond, Field): - if cond.indexed: - raise NotImplementedError("Where does not support condition on indexed string fields at present") - cond = cond.data[:] + if isinstance(cond, (NumericField, CategoricalField)): + cond = cond.data[:] + else: + raise NotImplementedError("Where only support condition on numeric field and categorical field at present.") elif callable(cond): raise NotImplementedError("module method `fields.where` doesn't support callable cond, please use instance mehthod `where` for callable cond.") @@ -61,9 +62,11 @@ def get_indices_and_values_from_non_indexed_string_field(other_field): other_field_row_count = len(other_field.data[:]) data_converted_to_str = np.where([True]*other_field_row_count, other_field.data[:], [""]*other_field_row_count) maxLength = 0 - re_match = re.findall(r"U(\d+)|S(\d+)", str(data_converted_to_str.dtype)) if re_match: - maxLength = int(re_match[0][0]) if re_match[0][0] else int(re_match[0][1]) + for l in re_match[0]: + if l: + maxLength = int(l) else: raise ValueError("The return dtype of instance method `where` doesn't match 'U(\d+)|S(\d+)", str(r_ndarray.dtype)) + if re_match: + for l in re_match[0]: + if l: + maxLength = int(l) else: raise ValueError("The return dtype of instance method `where` doesn't match ' Date: Tue, 21 Jun 2022 15:52:13 +0100 Subject: [PATCH 09/14] field indexing --- exetera/core/fields.py | 13 ++++++++++++- tests/test_fields.py | 17 ++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/exetera/core/fields.py b/exetera/core/fields.py index 6854ade..61602a4 100644 --- a/exetera/core/fields.py +++ b/exetera/core/fields.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Union +from typing import Callable, ItemsView, Optional, Union from datetime import datetime, timezone import operator @@ -251,6 +251,17 @@ def where(self, cond:Union[list, tuple, np.ndarray, Field, Callable], b, inplace return where_helper(cond, self, b) + def __getitem__(self, item:Union[list, tuple, np.ndarray]): + if isinstance(item, slice): + # TODO + pass + elif isinstance(item, int): + # TODO + pass + elif isinstance(item, (list, tuple, np.ndarray)): + filter_to_apply = np.array(item, dtype=np.int64) + # ? dstfld + self.apply_filter(filter_to_apply, dstfld=None) class MemoryField(Field): diff --git a/tests/test_fields.py b/tests/test_fields.py index 486666c..1c911bb 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -2188,7 +2188,6 @@ def test_indexed_string_isin(self, data, isin_data, expected): (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, 'float32'), (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA,"create_numeric", {"nformat": "float64"}, WHERE_NUMERIC_FIELD_DATA, 'float64'), (WHERE_BOOLEAN_COND, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA,"create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'int32'), - ] WHERE_FIXED_STRING_TESTS = [ @@ -2364,3 +2363,19 @@ def test_argsort(self, creator, name, kwargs, data): else: with self.assertRaises(ValueError): fields.argsort(f) + +INDEXING_TESTS = [ + (WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA), + # (WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA), + # (WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA), + # (WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA), + # (WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA), + # (WHERE_BOOLEAN_COND, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA), + # (WHERE_BOOLEAN_COND, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}) +] + +class TestFieldIndexingFunctions(SessionTestCase): + @parameterized.expand(INDEXING_TESTS) + def test(self, filter, creator, kwargs, data): + f = self.setup_field(self.df, creator, 'f', (), kwargs, data) + f[filter] From 127511d1eebbd99aea3cbc5ebf4cbab825671410 Mon Sep 17 00:00:00 2001 From: Liyuan Chen Date: Mon, 27 Jun 2022 14:15:57 +0100 Subject: [PATCH 10/14] add mem field test --- tests/test_fields.py | 51 ++++++++++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/tests/test_fields.py b/tests/test_fields.py index cbf2208..ff47fbf 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -2275,25 +2275,25 @@ def test_indexed_string_isin(self, data, isin_data, expected): np.testing.assert_array_equal(expected, result) -WHERE_BOOLEAN_COND = RAND_STATE.randint(0, 2, 20).tolist() +WHERE_BOOLEAN_COND = RAND_STATE.randint(0, 2, 20) WHERE_NUMERIC_FIELD_DATA = shuffle_randstate(list(range(-10,10))) WHERE_FIXED_STRING_FIELD_DATA = [b"aaa", b"bbb", b"eee", b"ccc", b" "]*4 -WHERE_CATEGORICAL_FIELD_DATA = RAND_STATE.randint(1, 4, 20).tolist() +WHERE_CATEGORICAL_FIELD_DATA = RAND_STATE.randint(1, 4, 20) WHERE_INDEXED_STRING_FIELD_DATA = (["a", "bb", "eeeee", "ccc", "dddd","", " "]*3)[:-1:] # make data length to 20 WHERE_NUMERIC_TESTS = [ - (lambda f: f > 5, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA, None, None, 0, 'int8'), - (lambda f: f > 5, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, None, None, -1.0, 'float64'), - (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, None, None, -1.0, 'float32'), - (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, WHERE_NUMERIC_FIELD_DATA, None, None, shuffle_randstate(list(range(0,20))), 'int64'), - (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, WHERE_NUMERIC_FIELD_DATA, None, None, np.array(shuffle_randstate(list(range(0,20))), dtype='int32'), 'int32'), - (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, None, None, np.array(shuffle_randstate(list(range(-20,0))), dtype='float32'), 'float32'), + # (lambda f: f > 5, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA, None, None, 0, 'int8'), + # (lambda f: f > 5, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, None, None, -1.0, 'float64'), + # (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, None, None, -1.0, 'float32'), + # (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, WHERE_NUMERIC_FIELD_DATA, None, None, shuffle_randstate(list(range(0,20))), 'int64'), + # (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, WHERE_NUMERIC_FIELD_DATA, None, None, np.array(shuffle_randstate(list(range(0,20))), dtype='int32'), 'int32'), + # (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, None, None, np.array(shuffle_randstate(list(range(-20,0))), dtype='float32'), 'float32'), (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, "create_categorical", {"nformat": "int8", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'float32'), - (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'float64'), - (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, 'float32'), - (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, 'float32'), - (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA,"create_numeric", {"nformat": "float64"}, WHERE_NUMERIC_FIELD_DATA, 'float64'), - (WHERE_BOOLEAN_COND, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA,"create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'int32'), + # (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'float64'), + # (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, 'float32'), + # (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, 'float32'), + # (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA,"create_numeric", {"nformat": "float64"}, WHERE_NUMERIC_FIELD_DATA, 'float64'), + # (WHERE_BOOLEAN_COND, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA,"create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'int32'), ] WHERE_FIXED_STRING_TESTS = [ @@ -2381,6 +2381,25 @@ def test_instance_field_where_return_numeric_mem_field(self, cond, a_creator, a_ self.assertEqual(result._nformat, expected_dtype) np.testing.assert_array_equal(result, expected_result) + # reload to test NumericMemField and CategoricalMemField + def reloadToMemField(field, field_data, kwarg): + if isinstance(field, fields.NumericField): + mem_field = fields.NumericMemField(self.s, kwarg['nformat']) + mem_field.data.write(np.array(field_data, dtype=kwarg['nformat'])) + elif isinstance(field, fields.CategoricalField): + mem_field = fields.CategoricalMemField(self.s, kwarg['nformat'], kwarg['key']) + mem_field.data.write(np.array(field_data, dtype=kwarg['nformat'])) + return mem_field + + a_mem_field = reloadToMemField(a_field, a_field_data, a_kwarg) + b_mem_field = reloadToMemField(b_field, b_data, b_kwarg) + + with self.subTest(f"Test instance where method: a is {type(a_mem_field)}, b is {type(b_mem_field)}"): + result = a_mem_field.where(cond, b_mem_field) + self.assertIsInstance(result, fields.NumericMemField) + self.assertEqual(result._nformat, expected_dtype) + np.testing.assert_array_equal(result, expected_result) + @parameterized.expand(WHERE_FIXED_STRING_TESTS) def test_instance_field_where_return_fixed_string_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_field_data): @@ -2422,13 +2441,13 @@ def test_instance_field_where_return_indexed_string_mem_field(self, cond, a_crea self.assertIsInstance(result, fields.IndexedStringMemField) np.testing.assert_array_equal(result.data[:], expected_result) - # reload to test FixedStringMemField + # reload to test IndexedStringMemField a_mem_field, b_mem_field = a_field, b_field - if isinstance(a_field, fields.IndexedStringMemField): + if isinstance(a_field, fields.IndexedStringField): a_mem_field = fields.IndexedStringMemField(self.s) a_mem_field.data.write(a_field_data) - if isinstance(b_field, fields.IndexedStringMemField): + if isinstance(b_field, fields.IndexedStringField): b_mem_field = fields.IndexedStringMemField(self.s) b_mem_field.data.write(b_field_data) From c0746fd16f8a35aff36acc1d8b2660472d7a2fa6 Mon Sep 17 00:00:00 2001 From: Liyuan Chen Date: Mon, 27 Jun 2022 15:37:25 +0100 Subject: [PATCH 11/14] move code change to another branch --- exetera/core/fields.py | 13 +------------ tests/test_fields.py | 40 ++++++++++++---------------------------- 2 files changed, 13 insertions(+), 40 deletions(-) diff --git a/exetera/core/fields.py b/exetera/core/fields.py index b78df5b..dd75a71 100644 --- a/exetera/core/fields.py +++ b/exetera/core/fields.py @@ -250,18 +250,7 @@ def where(self, cond:Union[list, tuple, np.ndarray, Field, Callable], b, inplace raise TypeError("'cond' parameter needs to be either callable lambda function, or array like, or NumericMemField.") return where_helper(cond, self, b) - - def __getitem__(self, item:Union[list, tuple, np.ndarray]): - if isinstance(item, slice): - # TODO - pass - elif isinstance(item, int): - # TODO - pass - elif isinstance(item, (list, tuple, np.ndarray)): - filter_to_apply = np.array(item, dtype=np.int64) - # ? dstfld - self.apply_filter(filter_to_apply, dstfld=None) + class MemoryField(Field): diff --git a/tests/test_fields.py b/tests/test_fields.py index ff47fbf..d7c968b 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -2282,18 +2282,18 @@ def test_indexed_string_isin(self, data, isin_data, expected): WHERE_INDEXED_STRING_FIELD_DATA = (["a", "bb", "eeeee", "ccc", "dddd","", " "]*3)[:-1:] # make data length to 20 WHERE_NUMERIC_TESTS = [ - # (lambda f: f > 5, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA, None, None, 0, 'int8'), - # (lambda f: f > 5, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, None, None, -1.0, 'float64'), - # (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, None, None, -1.0, 'float32'), - # (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, WHERE_NUMERIC_FIELD_DATA, None, None, shuffle_randstate(list(range(0,20))), 'int64'), - # (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, WHERE_NUMERIC_FIELD_DATA, None, None, np.array(shuffle_randstate(list(range(0,20))), dtype='int32'), 'int32'), - # (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, None, None, np.array(shuffle_randstate(list(range(-20,0))), dtype='float32'), 'float32'), + (lambda f: f > 5, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA, None, None, 0, 'int8'), + (lambda f: f > 5, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, None, None, -1.0, 'float64'), + (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, None, None, -1.0, 'float32'), + (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, WHERE_NUMERIC_FIELD_DATA, None, None, shuffle_randstate(list(range(0,20))), 'int64'), + (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, WHERE_NUMERIC_FIELD_DATA, None, None, np.array(shuffle_randstate(list(range(0,20))), dtype='int32'), 'int32'), + (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, None, None, np.array(shuffle_randstate(list(range(-20,0))), dtype='float32'), 'float32'), (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, "create_categorical", {"nformat": "int8", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'float32'), - # (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'float64'), - # (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, 'float32'), - # (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, 'float32'), - # (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA,"create_numeric", {"nformat": "float64"}, WHERE_NUMERIC_FIELD_DATA, 'float64'), - # (WHERE_BOOLEAN_COND, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA,"create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'int32'), + (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'float64'), + (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, 'float32'), + (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, 'float32'), + (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA,"create_numeric", {"nformat": "float64"}, WHERE_NUMERIC_FIELD_DATA, 'float64'), + (WHERE_BOOLEAN_COND, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA,"create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'int32'), ] WHERE_FIXED_STRING_TESTS = [ @@ -2487,20 +2487,4 @@ def test_argsort(self, creator, name, kwargs, data): self.assertListEqual(np.argsort(f.data[:]).tolist(), fields.argsort(f, dtype=kwargs['nformat']).data[:].tolist()) else: with self.assertRaises(ValueError): - fields.argsort(f) - -INDEXING_TESTS = [ - (WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA), - # (WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA), - # (WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA), - # (WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA), - # (WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA), - # (WHERE_BOOLEAN_COND, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA), - # (WHERE_BOOLEAN_COND, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}) -] - -class TestFieldIndexingFunctions(SessionTestCase): - @parameterized.expand(INDEXING_TESTS) - def test(self, filter, creator, kwargs, data): - f = self.setup_field(self.df, creator, 'f', (), kwargs, data) - f[filter] + fields.argsort(f) \ No newline at end of file From 78c7b950fd44b77d67eb958f11434a1665a5cdeb Mon Sep 17 00:00:00 2001 From: Liyuan Chen Date: Wed, 29 Jun 2022 18:00:19 +0100 Subject: [PATCH 12/14] add combination of field and its memfield --- tests/test_fields.py | 151 ++++++++++++++++++++++--------------------- 1 file changed, 76 insertions(+), 75 deletions(-) diff --git a/tests/test_fields.py b/tests/test_fields.py index d7c968b..da5a7e7 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -2330,6 +2330,22 @@ def where_oracle(cond, a, b): class TestFieldWhereFunctions(SessionTestCase): + def reloadToMemField(self, field, field_data, kwarg): + if isinstance(field, fields.NumericField): + mem_field = fields.NumericMemField(self.s, kwarg['nformat']) + mem_field.data.write(np.array(field_data, dtype=kwarg['nformat'])) + elif isinstance(field, fields.CategoricalField): + mem_field = fields.CategoricalMemField(self.s, kwarg['nformat'], kwarg['key']) + mem_field.data.write(np.array(field_data, dtype=kwarg['nformat'])) + elif isinstance(field, fields.FixedStringField): + mem_field = fields.FixedStringMemField(self.s, kwarg["length"]) + mem_field.data.write(np.array(field_data)) + elif isinstance(field, fields.IndexedStringField): + mem_field = fields.IndexedStringMemField(self.s) + mem_field.data.write(field_data) + return mem_field + + @parameterized.expand(WHERE_NUMERIC_TESTS) def test_module_fields_where(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_data, expected_dtype): """ @@ -2363,116 +2379,101 @@ def test_module_fields_where(self, cond, a_creator, a_kwarg, a_field_data, b_cre @parameterized.expand(WHERE_NUMERIC_TESTS) def test_instance_field_where_return_numeric_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_data, expected_dtype): a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data) + a_mem_field = self.reloadToMemField(a_field, a_field_data, a_kwarg) expected_result = where_oracle(cond, a_field_data, b_data) if b_kwarg is None: - with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_data)}"): - result = a_field.where(cond, b_data) - self.assertEqual(result._nformat, expected_dtype) - np.testing.assert_array_equal(result, expected_result) + combinations = [(a_field, b_data), + (a_mem_field, b_data)] + + for a, b in combinations: + with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_data)}"): + result = a.where(cond, b) + self.assertEqual(result._nformat, expected_dtype) + np.testing.assert_array_equal(result, expected_result) else: b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_data) + b_mem_field = self.reloadToMemField(b_field, b_data, b_kwarg) + combinations = [(a_field, b_field), + (a_field, b_mem_field), + (a_mem_field, b_field), + (a_mem_field, b_mem_field)] - with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"): - result = a_field.where(cond, b_field) - self.assertIsInstance(result, fields.NumericMemField) - self.assertEqual(result._nformat, expected_dtype) - np.testing.assert_array_equal(result, expected_result) - - # reload to test NumericMemField and CategoricalMemField - def reloadToMemField(field, field_data, kwarg): - if isinstance(field, fields.NumericField): - mem_field = fields.NumericMemField(self.s, kwarg['nformat']) - mem_field.data.write(np.array(field_data, dtype=kwarg['nformat'])) - elif isinstance(field, fields.CategoricalField): - mem_field = fields.CategoricalMemField(self.s, kwarg['nformat'], kwarg['key']) - mem_field.data.write(np.array(field_data, dtype=kwarg['nformat'])) - return mem_field - - a_mem_field = reloadToMemField(a_field, a_field_data, a_kwarg) - b_mem_field = reloadToMemField(b_field, b_data, b_kwarg) - - with self.subTest(f"Test instance where method: a is {type(a_mem_field)}, b is {type(b_mem_field)}"): - result = a_mem_field.where(cond, b_mem_field) - self.assertIsInstance(result, fields.NumericMemField) - self.assertEqual(result._nformat, expected_dtype) - np.testing.assert_array_equal(result, expected_result) + for a, b in combinations: + with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"): + result = a.where(cond, b) + self.assertIsInstance(result, fields.NumericMemField) + self.assertEqual(result._nformat, expected_dtype) @parameterized.expand(WHERE_FIXED_STRING_TESTS) def test_instance_field_where_return_fixed_string_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_field_data): a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data) b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_field_data) + a_mem_field = self.reloadToMemField(a_field, a_field_data, a_kwarg) + b_mem_field = self.reloadToMemField(b_field, b_field_data, b_kwarg) expected_result = where_oracle(cond, a_field_data, b_field_data) - with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"): - result = a_field.where(cond, b_field) - self.assertIsInstance(result, fields.FixedStringMemField) - np.testing.assert_array_equal(result.data[:], expected_result) + combinations = [(a_field, b_field), + (a_field, b_mem_field), + (a_mem_field, b_field), + (a_mem_field, b_mem_field)] - # reload to test FixedStringMemField - a_mem_field, b_mem_field = a_field, b_field - if isinstance(a_field, fields.FixedStringField): - a_mem_field = fields.FixedStringMemField(self.s, a_kwarg["length"]) - a_mem_field.data.write(np.array(a_field_data)) - - if isinstance(b_field, fields.FixedStringField): - b_mem_field = fields.FixedStringMemField(self.s, b_kwarg["length"]) - b_mem_field.data.write(np.array(b_field_data)) - - with self.subTest(f"Test instance where method: a is {type(a_mem_field)}, b is {type(b_mem_field)}"): - result = a_mem_field.where(cond, b_mem_field) - self.assertIsInstance(result, fields.FixedStringMemField) - np.testing.assert_array_equal(result.data[:], expected_result) + for a, b in combinations: + with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"): + result = a.where(cond, b) + self.assertIsInstance(result, fields.FixedStringMemField) + np.testing.assert_array_equal(result.data[:], expected_result) @parameterized.expand(WHERE_INDEXED_STRING_TESTS) def test_instance_field_where_return_indexed_string_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_field_data): a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data) b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_field_data) + a_mem_field = self.reloadToMemField(a_field, a_field_data, a_kwarg) + b_mem_field = self.reloadToMemField(b_field, b_field_data, b_kwarg) expected_result = where_oracle(cond, a_field_data, b_field_data) - with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"): - result = a_field.where(cond, b_field) - self.assertIsInstance(result, fields.IndexedStringMemField) - np.testing.assert_array_equal(result.data[:], expected_result) - - # reload to test IndexedStringMemField - a_mem_field, b_mem_field = a_field, b_field - if isinstance(a_field, fields.IndexedStringField): - a_mem_field = fields.IndexedStringMemField(self.s) - a_mem_field.data.write(a_field_data) - - if isinstance(b_field, fields.IndexedStringField): - b_mem_field = fields.IndexedStringMemField(self.s) - b_mem_field.data.write(b_field_data) - - with self.subTest(f"Test instance where method: a is {type(a_mem_field)}, b is {type(b_mem_field)}"): - result = a_mem_field.where(cond, b_mem_field) - self.assertIsInstance(result, fields.IndexedStringMemField) - np.testing.assert_array_equal(result.data[:], expected_result) + combinations = [(a_field, b_field), + (a_field, b_mem_field), + (a_mem_field, b_field), + (a_mem_field, b_mem_field)] + for a, b in combinations: + with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"): + result = a.where(cond, b) + self.assertIsInstance(result, fields.IndexedStringMemField) + np.testing.assert_array_equal(result.data[:], expected_result) + @parameterized.expand(WHERE_INDEXED_STRING_TESTS) def test_instance_field_where_with_cond_is_field(self, _, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_field_data): a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data) b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_field_data) + a_mem_field = self.reloadToMemField(a_field, a_field_data, a_kwarg) + b_mem_field = self.reloadToMemField(b_field, b_field_data, b_kwarg) cond = a_field - with self.subTest(f"Test instance where method: cond is a is {type(a_field)}, a is {type(a_field)}, b is {type(b_field)}"): - if isinstance(cond, (fields.NumericField, fields.CategoricalField)): - result = a_field.where(cond, b_field) - self.assertIsInstance(result, fields.IndexedStringMemField) + combinations = [(a_field, b_field), + (a_field, b_mem_field), + (a_mem_field, b_field), + (a_mem_field, b_mem_field)] - expected_result = where_oracle(cond, a_field_data, b_field_data) - np.testing.assert_array_equal(result.data[:], expected_result) - else: - with self.assertRaises(NotImplementedError) as context: - result = a_field.where(cond, b_field) - self.assertEqual(str(context.exception), "Where only support condition on numeric field and categorical field at present.") + for a, b in combinations: + with self.subTest(f"Test instance where method: cond is a is {type(a_field)}, a is {type(a_field)}, b is {type(b_field)}"): + if isinstance(cond, (fields.NumericField, fields.CategoricalField, fields.NumericMemField, fields.CategoricalMemField)): + result = a.where(cond, b) + self.assertIsInstance(result, fields.IndexedStringMemField) + + expected_result = where_oracle(cond, a_field_data, b_field_data) + np.testing.assert_array_equal(result.data[:], expected_result) + else: + with self.assertRaises(NotImplementedError) as context: + result = a.where(cond, b) + self.assertEqual(str(context.exception), "Where only support condition on numeric field and categorical field at present.") class TestFieldModuleFunctions(SessionTestCase): From 86fdeed18b466c3cba3ed582651cd73675a192db Mon Sep 17 00:00:00 2001 From: Liyuan Chen Date: Mon, 18 Jul 2022 13:22:50 +0100 Subject: [PATCH 13/14] add mem field check; add float dtype check --- exetera/core/fields.py | 6 +++--- tests/test_fields.py | 17 +++++++++++++---- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/exetera/core/fields.py b/exetera/core/fields.py index dd75a71..af2e18f 100644 --- a/exetera/core/fields.py +++ b/exetera/core/fields.py @@ -45,7 +45,7 @@ def where(cond: Union[list, tuple, np.ndarray, Field], a, b): if isinstance(cond, (list, tuple, np.ndarray)): cond = cond elif isinstance(cond, Field): - if isinstance(cond, (NumericField, CategoricalField)): + if isinstance(cond, (NumericField, NumericMemField, CategoricalField, CategoricalMemField)): cond = cond.data[:] else: raise NotImplementedError("Where only support condition on numeric field and categorical field at present.") @@ -240,7 +240,7 @@ def where(self, cond:Union[list, tuple, np.ndarray, Field, Callable], b, inplace if isinstance(cond, (list, tuple, np.ndarray)): cond = cond elif isinstance(cond, Field): - if isinstance(cond, (NumericField, CategoricalField)): + if isinstance(cond, (NumericField, NumericMemField, CategoricalField, CategoricalMemField)): cond = cond.data[:] else: raise NotImplementedError("Where only support condition on numeric field and categorical field at present.") @@ -335,7 +335,7 @@ def where(self, cond:Union[list, tuple, np.ndarray, Field, Callable], b, inplace if isinstance(cond, (list, tuple, np.ndarray)): cond = cond elif isinstance(cond, Field): - if isinstance(cond, (NumericField, CategoricalField)): + if isinstance(cond, (NumericField, NumericMemField, CategoricalField, CategoricalMemField)): cond = cond.data[:] else: raise NotImplementedError("Where only support condition on numeric field and categorical field at present.") diff --git a/tests/test_fields.py b/tests/test_fields.py index da5a7e7..8f13172 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -2298,6 +2298,7 @@ def test_indexed_string_isin(self, data, isin_data, expected): WHERE_FIXED_STRING_TESTS = [ (lambda f: f > 5, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA), + (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA), (lambda f: f > 2, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA), (WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA), (WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA), @@ -2314,7 +2315,7 @@ def test_indexed_string_isin(self, data, isin_data, expected): (WHERE_BOOLEAN_COND, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA), ] -def where_oracle(cond, a, b): +def where_oracle(cond, a, b, set_a_dtype = None, set_b_dtype=None): if callable(cond): if isinstance(a, fields.Field): cond = cond(a.data[:]) @@ -2322,9 +2323,14 @@ def where_oracle(cond, a, b): cond = cond(np.array(a)) elif isinstance(a, np.ndarray): cond = cond(a) - elif isinstance(cond, (fields.NumericField, fields.CategoricalField)): + elif isinstance(cond, (fields.NumericField, fields.NumericMemField, fields.CategoricalField, fields.CategoricalMemField)): cond = cond.data[:] + if set_a_dtype and isinstance(a, list): + a = np.array(a, dtype=set_a_dtype) + if set_b_dtype and isinstance(b, list): + b = np.array(b, dtype=set_b_dtype) + return np.where(cond, a, b) @@ -2415,12 +2421,15 @@ def test_instance_field_where_return_fixed_string_mem_field(self, cond, a_creato a_mem_field = self.reloadToMemField(a_field, a_field_data, a_kwarg) b_mem_field = self.reloadToMemField(b_field, b_field_data, b_kwarg) - expected_result = where_oracle(cond, a_field_data, b_field_data) + set_a_dtype= a_kwarg['nformat'] if 'nformat' in a_kwarg else None + set_b_dtype= b_kwarg['nformat'] if 'nformat' in b_kwarg else None + expected_result = where_oracle(cond, a_field_data, b_field_data, set_a_dtype, set_b_dtype) combinations = [(a_field, b_field), (a_field, b_mem_field), (a_mem_field, b_field), - (a_mem_field, b_mem_field)] + (a_mem_field, b_mem_field) + ] for a, b in combinations: with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"): From 0d51c111271fe7e5087679e2888c16f7a50135ce Mon Sep 17 00:00:00 2001 From: Liyuan Chen Date: Mon, 8 Aug 2022 13:50:07 +0100 Subject: [PATCH 14/14] fix exception handling messages --- exetera/core/fields.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/exetera/core/fields.py b/exetera/core/fields.py index af2e18f..d740777 100644 --- a/exetera/core/fields.py +++ b/exetera/core/fields.py @@ -48,9 +48,9 @@ def where(cond: Union[list, tuple, np.ndarray, Field], a, b): if isinstance(cond, (NumericField, NumericMemField, CategoricalField, CategoricalMemField)): cond = cond.data[:] else: - raise NotImplementedError("Where only support condition on numeric field and categorical field at present.") + raise NotImplementedError("where only supports python sequences, numpy ndarrays, and numeric field and categorical field types for the cond parameter at present.") elif callable(cond): - raise NotImplementedError("module method `fields.where` doesn't support callable cond, please use instance mehthod `where` for callable cond.") + raise NotImplementedError("module method fields.where doesn't support callable cond parameter, please use the instance method where if you need to use a callable cond parameter.") return where_helper(cond, a, b) @@ -68,7 +68,7 @@ def get_indices_and_values_from_non_indexed_string_field(other_field): if l: maxLength = int(l) else: - raise ValueError("The return dtype of instance method `where` doesn't match '