diff --git a/exetera/core/abstract_types.py b/exetera/core/abstract_types.py index 48bede9..6b725f9 100644 --- a/exetera/core/abstract_types.py +++ b/exetera/core/abstract_types.py @@ -85,6 +85,10 @@ def isin(self, test_elements:Union[list, set, np.ndarray]): def unique(self, return_index=False, return_inverse=False, return_counts=False): raise NotImplementedError() + @staticmethod + def where(cond, a, b): + raise NotImplementedError() + class Dataset(ABC): """ diff --git a/exetera/core/fields.py b/exetera/core/fields.py index 0448091..d740777 100644 --- a/exetera/core/fields.py +++ b/exetera/core/fields.py @@ -9,17 +9,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Union +from typing import Callable, ItemsView, Optional, Union from datetime import datetime, timezone import operator import numpy as np import h5py +import re from exetera.core.abstract_types import Field from exetera.core.data_writer import DataWriter from exetera.core import operations as ops from exetera.core import validation as val +from exetera.core import utils def isin(field:Field, test_elements:Union[list, set, np.ndarray]): @@ -39,6 +41,97 @@ def isin(field:Field, test_elements:Union[list, set, np.ndarray]): return ret +def where(cond: Union[list, tuple, np.ndarray, Field], a, b): + if isinstance(cond, (list, tuple, np.ndarray)): + cond = cond + elif isinstance(cond, Field): + if isinstance(cond, (NumericField, NumericMemField, CategoricalField, CategoricalMemField)): + cond = cond.data[:] + else: + raise NotImplementedError("where only supports python sequences, numpy ndarrays, and numeric field and categorical field types for the cond parameter at present.") + elif callable(cond): + raise NotImplementedError("module method fields.where doesn't support callable cond parameter, please use the instance method where if you need to use a callable cond parameter.") + + return where_helper(cond, a, b) + + +def where_helper(cond:Union[list, tuple, np.ndarray, Field], a, b) -> Field: + + def get_indices_and_values_from_non_indexed_string_field(other_field): + # convert other field data to string array + other_field_row_count = len(other_field.data[:]) + data_converted_to_str = np.where([True]*other_field_row_count, other_field.data[:], [""]*other_field_row_count) + maxLength = 0 + re_match = re.findall(r"U(\d+)|S(\d+)", str(data_converted_to_str.dtype)) + if re_match: + for l in re_match[0]: + if l: + maxLength = int(l) + else: + raise ValueError("The return dtype of instance method where doesn't match 'U(\d+)|S(\d+)", str(r_ndarray.dtype)) + if re_match: + for l in re_match[0]: + if l: + maxLength = int(l) + else: + raise ValueError("The return dtype of instance method where doesn't match ' 5, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA, None, None, 0, 'int8'), + (lambda f: f > 5, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, None, None, -1.0, 'float64'), + (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, None, None, -1.0, 'float32'), + (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, WHERE_NUMERIC_FIELD_DATA, None, None, shuffle_randstate(list(range(0,20))), 'int64'), + (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, WHERE_NUMERIC_FIELD_DATA, None, None, np.array(shuffle_randstate(list(range(0,20))), dtype='int32'), 'int32'), + (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, None, None, np.array(shuffle_randstate(list(range(-20,0))), dtype='float32'), 'float32'), + (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, "create_categorical", {"nformat": "int8", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'float32'), + (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'float64'), + (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, 'float32'), + (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, 'float32'), + (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA,"create_numeric", {"nformat": "float64"}, WHERE_NUMERIC_FIELD_DATA, 'float64'), + (WHERE_BOOLEAN_COND, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA,"create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'int32'), +] + +WHERE_FIXED_STRING_TESTS = [ + (lambda f: f > 5, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA), + (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA), + (lambda f: f > 2, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA), + (WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA), + (WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA), + (WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA, "create_fixed_string", {"length": 1}, [b"a", b"b", b"e", b"c", b" "]*4 ), +] + +WHERE_INDEXED_STRING_TESTS = [ + (WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA, "create_indexed_string", {}, shuffle_randstate(WHERE_INDEXED_STRING_FIELD_DATA)), + (WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA), + (WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA), + (WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA), + (WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA), + (WHERE_BOOLEAN_COND, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA), + (WHERE_BOOLEAN_COND, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA), +] + +def where_oracle(cond, a, b, set_a_dtype = None, set_b_dtype=None): + if callable(cond): + if isinstance(a, fields.Field): + cond = cond(a.data[:]) + elif isinstance(a, list): + cond = cond(np.array(a)) + elif isinstance(a, np.ndarray): + cond = cond(a) + elif isinstance(cond, (fields.NumericField, fields.NumericMemField, fields.CategoricalField, fields.CategoricalMemField)): + cond = cond.data[:] + + if set_a_dtype and isinstance(a, list): + a = np.array(a, dtype=set_a_dtype) + if set_b_dtype and isinstance(b, list): + b = np.array(b, dtype=set_b_dtype) + + return np.where(cond, a, b) + + +class TestFieldWhereFunctions(SessionTestCase): + + def reloadToMemField(self, field, field_data, kwarg): + if isinstance(field, fields.NumericField): + mem_field = fields.NumericMemField(self.s, kwarg['nformat']) + mem_field.data.write(np.array(field_data, dtype=kwarg['nformat'])) + elif isinstance(field, fields.CategoricalField): + mem_field = fields.CategoricalMemField(self.s, kwarg['nformat'], kwarg['key']) + mem_field.data.write(np.array(field_data, dtype=kwarg['nformat'])) + elif isinstance(field, fields.FixedStringField): + mem_field = fields.FixedStringMemField(self.s, kwarg["length"]) + mem_field.data.write(np.array(field_data)) + elif isinstance(field, fields.IndexedStringField): + mem_field = fields.IndexedStringMemField(self.s) + mem_field.data.write(field_data) + return mem_field + + + @parameterized.expand(WHERE_NUMERIC_TESTS) + def test_module_fields_where(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_data, expected_dtype): + """ + Test `where` for the numeric fields using `fields.where` function and the object's method. + """ + a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data) + expected_result = where_oracle(cond, a_field_data, b_data) + + if b_kwarg is None: + with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_data)}"): + if callable(cond): + with self.assertRaises(NotImplementedError) as context: + result = fields.where(cond, a_field, b_data) + self.assertEqual(str(context.exception), "module method `fields.where` doesn't support callable cond, please use instance mehthod `where` for callable cond.") + else: + result = fields.where(cond, a_field, b_data) + np.testing.assert_array_equal(expected_result, result) + + else: + b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_data) + with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"): + if callable(cond): + with self.assertRaises(NotImplementedError) as context: + result = fields.where(cond, a_field, b_field) + self.assertEqual(str(context.exception), "module method `fields.where` doesn't support callable cond, please use instance mehthod `where` for callable cond.") + else: + result = fields.where(cond, a_field, b_field) + np.testing.assert_array_equal(expected_result, result) + + + @parameterized.expand(WHERE_NUMERIC_TESTS) + def test_instance_field_where_return_numeric_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_data, expected_dtype): + a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data) + a_mem_field = self.reloadToMemField(a_field, a_field_data, a_kwarg) + + expected_result = where_oracle(cond, a_field_data, b_data) + + if b_kwarg is None: + combinations = [(a_field, b_data), + (a_mem_field, b_data)] + + for a, b in combinations: + with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_data)}"): + result = a.where(cond, b) + self.assertEqual(result._nformat, expected_dtype) + np.testing.assert_array_equal(result, expected_result) + + else: + b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_data) + b_mem_field = self.reloadToMemField(b_field, b_data, b_kwarg) + combinations = [(a_field, b_field), + (a_field, b_mem_field), + (a_mem_field, b_field), + (a_mem_field, b_mem_field)] + + for a, b in combinations: + with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"): + result = a.where(cond, b) + self.assertIsInstance(result, fields.NumericMemField) + self.assertEqual(result._nformat, expected_dtype) + + + @parameterized.expand(WHERE_FIXED_STRING_TESTS) + def test_instance_field_where_return_fixed_string_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_field_data): + a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data) + b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_field_data) + a_mem_field = self.reloadToMemField(a_field, a_field_data, a_kwarg) + b_mem_field = self.reloadToMemField(b_field, b_field_data, b_kwarg) + + set_a_dtype= a_kwarg['nformat'] if 'nformat' in a_kwarg else None + set_b_dtype= b_kwarg['nformat'] if 'nformat' in b_kwarg else None + expected_result = where_oracle(cond, a_field_data, b_field_data, set_a_dtype, set_b_dtype) + + combinations = [(a_field, b_field), + (a_field, b_mem_field), + (a_mem_field, b_field), + (a_mem_field, b_mem_field) + ] + + for a, b in combinations: + with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"): + result = a.where(cond, b) + self.assertIsInstance(result, fields.FixedStringMemField) + np.testing.assert_array_equal(result.data[:], expected_result) + + + @parameterized.expand(WHERE_INDEXED_STRING_TESTS) + def test_instance_field_where_return_indexed_string_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_field_data): + a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data) + b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_field_data) + a_mem_field = self.reloadToMemField(a_field, a_field_data, a_kwarg) + b_mem_field = self.reloadToMemField(b_field, b_field_data, b_kwarg) + + expected_result = where_oracle(cond, a_field_data, b_field_data) + + combinations = [(a_field, b_field), + (a_field, b_mem_field), + (a_mem_field, b_field), + (a_mem_field, b_mem_field)] + for a, b in combinations: + with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"): + result = a.where(cond, b) + self.assertIsInstance(result, fields.IndexedStringMemField) + np.testing.assert_array_equal(result.data[:], expected_result) + + + @parameterized.expand(WHERE_INDEXED_STRING_TESTS) + def test_instance_field_where_with_cond_is_field(self, _, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_field_data): + a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data) + b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_field_data) + a_mem_field = self.reloadToMemField(a_field, a_field_data, a_kwarg) + b_mem_field = self.reloadToMemField(b_field, b_field_data, b_kwarg) + cond = a_field + + combinations = [(a_field, b_field), + (a_field, b_mem_field), + (a_mem_field, b_field), + (a_mem_field, b_mem_field)] + + for a, b in combinations: + with self.subTest(f"Test instance where method: cond is a is {type(a_field)}, a is {type(a_field)}, b is {type(b_field)}"): + if isinstance(cond, (fields.NumericField, fields.CategoricalField, fields.NumericMemField, fields.CategoricalMemField)): + result = a.where(cond, b) + self.assertIsInstance(result, fields.IndexedStringMemField) + + expected_result = where_oracle(cond, a_field_data, b_field_data) + np.testing.assert_array_equal(result.data[:], expected_result) + else: + with self.assertRaises(NotImplementedError) as context: + result = a.where(cond, b) + self.assertEqual(str(context.exception), "Where only support condition on numeric field and categorical field at present.") + + class TestFieldModuleFunctions(SessionTestCase): @parameterized.expand(DEFAULT_FIELD_DATA) @@ -2287,4 +2497,4 @@ def test_argsort(self, creator, name, kwargs, data): self.assertListEqual(np.argsort(f.data[:]).tolist(), fields.argsort(f, dtype=kwargs['nformat']).data[:].tolist()) else: with self.assertRaises(ValueError): - fields.argsort(f) + fields.argsort(f) \ No newline at end of file