-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implement where api #298
base: master
Are you sure you want to change the base?
implement where api #298
Changes from 5 commits
aa7301a
834b5a9
fd79955
a564305
fa62e3b
dea5b9c
bd51519
6da082d
0f83acd
213ddac
127511d
c0746fd
78c7b95
86fdeed
0d51c11
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,11 +15,13 @@ | |
|
||
import numpy as np | ||
import h5py | ||
import re | ||
|
||
from exetera.core.abstract_types import Field | ||
from exetera.core.data_writer import DataWriter | ||
from exetera.core import operations as ops | ||
from exetera.core import validation as val | ||
from exetera.core import utils | ||
|
||
|
||
def isin(field:Field, test_elements:Union[list, set, np.ndarray]): | ||
|
@@ -39,6 +41,23 @@ def isin(field:Field, test_elements:Union[list, set, np.ndarray]): | |
return ret | ||
|
||
|
||
def where(cond: Union[list, tuple, np.ndarray, Field], a, b):
    """
    Element-wise selection between `a` and `b`, delegating to `numpy.where`.

    :param cond: a list, tuple, numpy array or non-indexed field; truthy entries select from `a`,
        falsy entries select from `b`. Any other non-callable value is handed to `numpy.where` unchanged.
    :param a: values used where `cond` is truthy; a field is replaced by its `data[:]`.
    :param b: values used where `cond` is falsy; a field is replaced by its `data[:]`.
    :return: the numpy array produced by `numpy.where` (not a field).
    :raises NotImplementedError: if `cond` is an indexed field, or if `cond` is callable.
    """
    array_like_cond = isinstance(cond, (list, tuple, np.ndarray))
    if not array_like_cond:
        if isinstance(cond, Field):
            if cond.indexed:
                raise NotImplementedError("Where does not support condition on indexed string fields at present")
            cond = cond.data[:]
        elif callable(cond):
            # NOTE(review): "mehthod" is a typo for "method". The text is kept verbatim because
            # the unit tests compare against this exact message; fix both together.
            raise NotImplementedError("module method `fields.where` doesn't support callable cond, please use instance mehthod `where` for callable cond.")

    # Fields are unwrapped to their in-memory data; anything else is passed through untouched.
    operands = [x.data[:] if isinstance(x, Field) else x for x in (a, b)]
    return np.where(cond, *operands)
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is still returning a numpy array rather than a field. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The logic of module-level |
||
|
||
|
||
class HDF5Field(Field): | ||
def __init__(self, session, group, dataframe, write_enabled=False): | ||
super().__init__() | ||
|
@@ -143,6 +162,110 @@ def _ensure_valid(self): | |
if not self._valid_reference: | ||
raise ValueError("This field no longer refers to a valid underlying field object") | ||
|
||
def where(self, cond:Union[list, tuple, np.ndarray, Field, Callable], b, inplace=False):
    """
    Select, row by row, between this field and `b` and return the result as a new memory field.

    :param cond: a list, tuple or numpy array of truth values; a non-indexed field (its `data[:]`
        is used); or a callable, which is applied to `self.data[:]` to produce the truth values.
    :param b: the alternative operand used where `cond` is falsy. A field or, when neither operand
        is an indexed string field, anything `numpy.where` accepts.
    :param inplace: currently ignored; in-place update is not implemented.
    :return: an `IndexedStringMemField` when either operand is an `IndexedStringField`; otherwise a
        `FixedStringMemField` when either operand is a `FixedStringField`; otherwise a
        `NumericMemField` when the result dtype is listed in `utils.PERMITTED_NUMERIC_TYPES`.
    :raises NotImplementedError: if `cond` is an indexed field, or the operand types are unsupported.
    :raises TypeError: if `cond` is none of the accepted kinds.
    :raises ValueError: if `cond` and the operands differ in row count, or a string length cannot
        be derived from the numpy dtype.
    """
    if isinstance(cond, (list, tuple, np.ndarray)):
        pass  # already usable as a condition
    elif isinstance(cond, Field):
        if cond.indexed:
            raise NotImplementedError("Where does not support indexed string fields at present")
        cond = cond.data[:]
    elif callable(cond):
        cond = cond(self.data[:])
    else:
        raise TypeError("'cond' parameter needs to be either callable lambda function, or array like, or NumericMemField")

    # NOTE(review): reviewers asked that these type checks accept every variant of each field
    # type, not just one class -- TODO confirm which classes must be covered.
    self_is_indexed = isinstance(self, IndexedStringField)
    b_is_indexed = isinstance(b, IndexedStringField)

    if self_is_indexed and b_is_indexed:
        a_indices, a_values = self.indices[:], self.values[:]
        b_indices, b_values = b.indices[:], b.values[:]
        a_rows, b_rows = len(a_indices) - 1, len(b_indices) - 1
        if len(cond) != a_rows or len(cond) != b_rows:
            raise ValueError(f"operands can't work with shapes ({len(cond)},) ({a_rows},) ({b_rows},)")
        return self._where_indexed_strings(cond, a_indices, a_values, b_indices, b_values)

    if self_is_indexed or b_is_indexed:
        # Exactly one operand is an indexed string field: the other one is converted to the
        # indices/values layout so that the result is an indexed string field as well.
        indexed_str_field, other_field = (self, b) if self_is_indexed else (b, self)
        indexed_indices = indexed_str_field.indices[:]
        other_data = other_field.data[:]

        indexed_str_field_row_count = len(indexed_indices) - 1
        other_field_row_count = len(other_data)
        if len(cond) != indexed_str_field_row_count or len(cond) != other_field_row_count:
            raise ValueError(f"operands can't work with shapes ({len(cond)},) ({indexed_str_field_row_count},) ({other_field_row_count},)")

        indexed_values = indexed_str_field.values[:]
        other_indices, other_values = self._encode_as_indexed_strings(other_data)

        # Keep `self` as the first operand, following a.where(cond, b).
        if self_is_indexed:
            return self._where_indexed_strings(cond, indexed_indices, indexed_values, other_indices, other_values)
        return self._where_indexed_strings(cond, other_indices, other_values, indexed_indices, indexed_values)

    b_data = b.data[:] if isinstance(b, Field) else b
    r_ndarray = np.where(cond, self.data[:], b_data)

    if isinstance(self, FixedStringField) or isinstance(b, FixedStringField):
        length = self._string_dtype_length(r_ndarray.dtype)
        result_mem_field = FixedStringMemField(self._session, length)
        result_mem_field.data.write(r_ndarray)
    elif str(r_ndarray.dtype) in utils.PERMITTED_NUMERIC_TYPES:
        result_mem_field = NumericMemField(self._session, str(r_ndarray.dtype))
        result_mem_field.data.write(r_ndarray)
    else:
        raise NotImplementedError(f"instance method `where` doesn't support the current input type: {type(self)} and {type(b)}")

    # TODO: honour `inplace` (write the result back into this field) once supported.
    return result_mem_field

def _string_dtype_length(self, dtype):
    """Return the character count encoded in a '<U<n>' or 'S<n>' numpy dtype; raise ValueError otherwise."""
    matches = re.findall(r"<U(\d+)|S(\d+)", str(dtype))
    if not matches:
        raise ValueError(r"The return dtype of instance method `where` doesn't match '<U(\d+)' or 'S(\d+)' when one of the field is FixedStringField")
    unicode_length, bytes_length = matches[0]
    return int(unicode_length) if unicode_length else int(bytes_length)

def _encode_as_indexed_strings(self, data):
    """Convert `data` to strings and return them as an (indices, values) pair of int64 / uint8 arrays."""
    row_count = len(data)
    # numpy.where against empty strings forces numpy to promote every entry to a string dtype.
    data_converted_to_str = np.where([True]*row_count, data, [""]*row_count)
    max_length = self._string_dtype_length(data_converted_to_str.dtype)

    indices = np.zeros(row_count + 1, dtype=np.int64)
    # Sized for the widest string in every row; only the first indices[-1] bytes are meaningful.
    values = np.zeros(np.int64(row_count*max_length), dtype=np.uint8)
    for i, s in enumerate(data_converted_to_str):
        encoded_s = np.array(list(s), dtype='S1').view(np.uint8)
        indices[i + 1] = indices[i] + len(encoded_s)
        values[indices[i]:indices[i + 1]] = encoded_s
    return indices, values

def _where_indexed_strings(self, cond, a_indices, a_values, b_indices, b_values):
    """Run the indexed-string `where` kernel on two (indices, values) pairs and wrap the result in an IndexedStringMemField."""
    r_indices = np.zeros(len(a_indices), dtype=np.int64)

    # Every output row comes from either `a` or `b` (assuming the kernel copies the chosen row
    # as-is), so the sum of the per-row maximum lengths bounds the output size. Sizing by
    # max(len(a_values), len(b_values)) alone is too small when long rows are taken alternately
    # from both operands.
    row_capacity = np.maximum(np.diff(a_indices), np.diff(b_indices))
    capacity = max(int(row_capacity.sum()), len(a_values), len(b_values))
    r_values = np.zeros(capacity, dtype=np.uint8)

    ops.where_for_two_indexed_string_fields(np.array(cond), a_indices, a_values, b_indices, b_values, r_indices, r_values)

    # Drop the unused tail of the over-allocated buffer.
    r_values = r_values[:r_indices[-1]]

    result_mem_field = IndexedStringMemField(self._session)
    result_mem_field.indices.write(r_indices)
    result_mem_field.values.write(r_values)
    return result_mem_field
|
||
|
||
class MemoryField(Field): | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2169,6 +2169,155 @@ def test_indexed_string_isin(self, data, isin_data, expected): | |
np.testing.assert_array_equal(expected, result) | ||
|
||
|
||
# Shared fixtures for the `where` tests. Every data set has 20 rows.
# The RAND_STATE draws below are order-sensitive: reordering these statements changes the data.
WHERE_BOOLEAN_COND = RAND_STATE.randint(0, 2, 20).tolist()
# NOTE(review): no case below uses a field as `cond`; reviewers flagged this as a missing test.

WHERE_NUMERIC_FIELD_DATA = shuffle_randstate(list(range(-10,10)))
WHERE_FIXED_STRING_FIELD_DATA = [b"aaa", b"bbb", b"eee", b"ccc", b" "]*4
WHERE_CATEGORICAL_FIELD_DATA = RAND_STATE.randint(1, 4, 20).tolist()
WHERE_INDEXED_STRING_FIELD_DATA = (["a", "bb", "eeeee", "ccc", "dddd","", " "]*3)[:-1:] # 7 * 3 = 21 entries, trimmed to 20

# Each entry: (cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_data, expected_dtype).
# When b_creator / b_kwarg are None, b_data is passed to `where` directly instead of as a field.
WHERE_NUMERIC_TESTS = [
    (lambda f: f > 5, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA, None, None, 0, 'int8'),
    (lambda f: f > 5, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, None, None, -1.0, 'float64'),
    (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, None, None, -1.0, 'float32'),
    (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, WHERE_NUMERIC_FIELD_DATA, None, None, shuffle_randstate(list(range(0,20))), 'int64'),
    (lambda f: f > 5, "create_numeric", {"nformat": "int32"}, WHERE_NUMERIC_FIELD_DATA, None, None, np.array(shuffle_randstate(list(range(0,20))), dtype='int32'), 'int32'),
    (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, None, None, np.array(shuffle_randstate(list(range(-20,0))), dtype='float32'), 'float32'),
    (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, "create_categorical", {"nformat": "int8", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'float32'),
    (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'float64'),
    (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, 'float32'),
    # NOTE(review): the next case is identical to the one above.
    (lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, 'float32'),
    (lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA,"create_numeric", {"nformat": "float64"}, WHERE_NUMERIC_FIELD_DATA, 'float64'),
    (WHERE_BOOLEAN_COND, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA,"create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'int32'),

]

# Each entry: (cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_field_data).
# Every case pairs a fixed string field with a numeric, categorical or fixed string field.
WHERE_FIXED_STRING_TESTS = [
    (lambda f: f > 5, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA),
    (lambda f: f > 2, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA),
    (WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA),
    # NOTE(review): a reviewer asked whether a float32 numeric case can be added here as well.
    (WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA),
    (WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA, "create_fixed_string", {"length": 1}, [b"a", b"b", b"e", b"c", b" "]*4 ),
]

# Each entry: (cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_field_data).
# At least one operand of every case is an indexed string field.
WHERE_INDEXED_STRING_TESTS = [
    (WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA, "create_indexed_string", {}, shuffle_randstate(WHERE_INDEXED_STRING_FIELD_DATA)),
    (WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA),
    (WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA),
    (WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA),
    (WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA),
    (WHERE_BOOLEAN_COND, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA),
    (WHERE_BOOLEAN_COND, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA),
]
|
||
def where_oracle(cond, a, b):
    """
    Reference implementation the `where` tests compare against.

    A callable `cond` is first evaluated on `a` (on `a.data[:]` for a field, on an array built
    from a list, or on a numpy array as-is) to obtain the mask; for any other kind of `a` the
    callable is left as it is. The final selection is done by `numpy.where(cond, a, b)`.
    """
    if not callable(cond):
        return np.where(cond, a, b)

    if isinstance(a, fields.Field):
        mask = cond(a.data[:])
    elif isinstance(a, list):
        mask = cond(np.array(a))
    elif isinstance(a, np.ndarray):
        mask = cond(a)
    else:
        # No known way to evaluate the callable on `a`; hand it to numpy unchanged.
        mask = cond
    return np.where(mask, a, b)
|
||
|
||
class TestFieldWhereFunctions(SessionTestCase):
    """
    Tests for the module-level `fields.where` function and for the `where` instance method of fields.
    """

    # Message raised by `fields.where` for a callable `cond`. The spelling "mehthod" mirrors the
    # text currently raised by the library; the two must be changed together.
    _CALLABLE_COND_MESSAGE = "module method `fields.where` doesn't support callable cond, please use instance mehthod `where` for callable cond."

    def _make_b_operand(self, b_creator, b_kwarg, b_data):
        """Return `b_data` itself when no creator arguments are given, otherwise a field holding it."""
        if b_kwarg is None:
            return b_data
        return self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_data)

    @parameterized.expand(WHERE_NUMERIC_TESTS)
    def test_module_fields_where(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_data, expected_dtype):
        """
        Test the module-level `fields.where`: a callable `cond` must be rejected with
        NotImplementedError, any other `cond` must give the same values as `numpy.where`.
        """
        a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data)
        expected_result = where_oracle(cond, a_field_data, b_data)
        b = self._make_b_operand(b_creator, b_kwarg, b_data)

        with self.subTest(f"Test module where function: a is {type(a_field)}, b is {type(b)}"):
            if callable(cond):
                with self.assertRaises(NotImplementedError) as context:
                    fields.where(cond, a_field, b)
                # Checked after the `with` block: nothing following the raising call inside it would run.
                self.assertEqual(str(context.exception), self._CALLABLE_COND_MESSAGE)
            else:
                result = fields.where(cond, a_field, b)
                np.testing.assert_array_equal(expected_result, result)

    @parameterized.expand(WHERE_NUMERIC_TESTS)
    def test_instance_field_where_return_numeric_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_data, expected_dtype):
        """
        Test that the instance method `where` on numeric / categorical operands returns a
        NumericMemField with the expected dtype and values.
        """
        a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data)
        expected_result = where_oracle(cond, a_field_data, b_data)
        b = self._make_b_operand(b_creator, b_kwarg, b_data)

        with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b)}"):
            result = a_field.where(cond, b)
            self.assertIsInstance(result, fields.NumericMemField)
            self.assertEqual(result._nformat, expected_dtype)
            # Compare the stored values, as the string-field tests below do.
            np.testing.assert_array_equal(result.data[:], expected_result)

    @parameterized.expand(WHERE_FIXED_STRING_TESTS)
    def test_instance_field_where_return_fixed_string_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_field_data):
        """
        Test that the instance method `where` returns a FixedStringMemField when a fixed string
        field is combined with a numeric, categorical or fixed string field.
        """
        a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data)
        b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_field_data)

        expected_result = where_oracle(cond, a_field_data, b_field_data)

        # NOTE(review): a reviewer asked for this sub-test to come after memory fields are
        # created -- presumably so that memory-field operands get covered too; TODO confirm.
        with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"):
            result = a_field.where(cond, b_field)
            self.assertIsInstance(result, fields.FixedStringMemField)
            np.testing.assert_array_equal(result.data[:], expected_result)

    @parameterized.expand(WHERE_INDEXED_STRING_TESTS)
    def test_instance_field_where_return_indexed_string_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_field_data):
        """
        Test that the instance method `where` returns an IndexedStringMemField when at least one
        operand is an indexed string field.
        """
        # NOTE(review): a reviewer asked for combinations of hdf5 and memory fields here as well.
        a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data)
        b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_field_data)

        expected_result = where_oracle(cond, a_field_data, b_field_data)

        with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"):
            result = a_field.where(cond, b_field)
            self.assertIsInstance(result, fields.IndexedStringMemField)
            np.testing.assert_array_equal(result.data[:], expected_result)

    # Disabled test for `inplace=True`, kept for reference (presumably until in-place
    # `where` is supported -- TODO confirm).
    # def test_instance_where_numeric_inplace(self):
    #     input_data = [1,2,3,5,9,8,6,4,7,0]
    #     data = np.asarray(input_data, dtype=np.int32)
    #     bio = BytesIO()
    #     with session.Session() as s:
    #         src = s.open_dataset(bio, 'w', 'src')
    #         df = src.create_dataframe('df')
    #         f = df.create_numeric('foo', 'int32')
    #         f.data.write(data)
    #
    #         r = f.where(f > 5, 0)
    #         self.assertEqual(list(f.data[:]), [1,2,3,5,9,8,6,4,7,0])
    #         r = f.where(f > 5, 0, inplace=True)
    #         self.assertEqual(list(f.data[:]), [0,0,0,0,9,8,6,0,7,0])
    #
|
||
|
||
class TestFieldModuleFunctions(SessionTestCase): | ||
|
||
@parameterized.expand(DEFAULT_FIELD_DATA) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo: mehthod -> method