Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement where api #298

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions exetera/core/abstract_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def isin(self, test_elements:Union[list, set, np.ndarray]):
def unique(self, return_index=False, return_inverse=False, return_counts=False):
raise NotImplementedError()

@staticmethod
def where(cond, a, b):
raise NotImplementedError()


class Dataset(ABC):
"""
Expand Down
117 changes: 117 additions & 0 deletions exetera/core/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@

import numpy as np
import h5py
import re

from exetera.core.abstract_types import Field
from exetera.core.data_writer import DataWriter
from exetera.core import operations as ops
from exetera.core import validation as val
from exetera.core import utils


def isin(field:Field, test_elements:Union[list, set, np.ndarray]):
Expand All @@ -39,6 +41,92 @@ def isin(field:Field, test_elements:Union[list, set, np.ndarray]):
return ret


def where(cond: Union[list, tuple, np.ndarray, Field], a, b):
if isinstance(cond, (list, tuple, np.ndarray)):
cond = cond
elif isinstance(cond, Field):
if cond.indexed:
raise NotImplementedError("Where does not support condition on indexed string fields at present")
cond = cond.data[:]
elif callable(cond):
raise NotImplementedError("module method `fields.where` doesn't support callable cond, please use instance mehthod `where` for callable cond.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: mehthod -> method

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo, please replace with:

"module method fields.where doesn't support callable cond parameter, please use the instance method where if you need to use a callable cond parameter"


return where_helper(cond, a, b)


def where_helper(cond:Union[list, tuple, np.ndarray, Field], a, b) -> Field:

def get_indices_and_values_from_non_indexed_string_field(other_field):
# convert other field data to string array
other_field_row_count = len(other_field.data[:])
data_converted_to_str = np.where([True]*other_field_row_count, other_field.data[:], [""]*other_field_row_count)
maxLength = 0
re_match = re.findall(r"<U(\d+)|S(\d+)", str(data_converted_to_str.dtype))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

U can be <U or >U

if re_match:
maxLength = int(re_match[0][0]) if re_match[0][0] else int(re_match[0][1])
else:
raise ValueError("The return dtype of instance method `where` doesn't match '<U(\d+)' or 'S(\d+)' when one of the field is FixedStringField")

# convert other field string array to indices and values
indices = np.zeros(other_field_row_count + 1, dtype=np.int64)
values = np.zeros(np.int64(other_field_row_count*maxLength), dtype=np.uint8)
for i, s in enumerate(data_converted_to_str):
encoded_s = np.array(list(s), dtype='S1').view(np.uint8)
indices[i + 1] = indices[i] + len(encoded_s)
values[indices[i]:indices[i + 1]] = encoded_s
return indices, values

def get_indices_and_values_from_all_field(f):
if isinstance(f, (IndexedStringField, IndexedStringMemField)):
indices, values = f.indices[:], f.values[:]
else:
indices, values = get_indices_and_values_from_non_indexed_string_field(f)
return indices, values

result_mem_field = None

if isinstance(a, (IndexedStringField, IndexedStringMemField)) or isinstance(b, (IndexedStringField, IndexedStringMemField)):
a_indices, a_values = get_indices_and_values_from_all_field(a)
b_indices, b_values = get_indices_and_values_from_all_field(b)

if len(cond) != len(a_indices) - 1 or len(cond) != len(b_indices) - 1:
raise ValueError(f"operands can't work with shapes ({len(cond)},) ({len(a_indices) - 1},) ({len(b_indices) - 1},)")

# get indices and values for result
r_indices = np.zeros(len(a_indices), dtype=np.int64)
r_values = np.zeros(max(len(a_values), len(b_values)), dtype=np.uint8)
ops.where_for_two_indexed_string_fields(np.array(cond), a_indices, a_values, b_indices, b_values, r_indices, r_values)
r_values = r_values[:r_indices[-1]]

# return IndexStringMemField
result_mem_field = IndexedStringMemField(a._session)
result_mem_field.indices.write(r_indices)
result_mem_field.values.write(r_values)

else:
b_data = b.data[:] if isinstance(b, Field) else b
r_ndarray = np.where(cond, a.data[:], b_data)

if isinstance(a, (FixedStringField, FixedStringMemField)) or isinstance(b, (FixedStringField, FixedStringMemField)):
length = 0
result = re.findall(r"<U(\d+)|S(\d+)", str(r_ndarray.dtype))
if result:
length = int(result[0][0]) if result[0][0] else int(result[0][1])
else:
raise ValueError("The return dtype of instance method `where` doesn't match '<U(\d+)' or 'S(\d+)' when one of the field is FixedStringField")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo, please replace with:

"The return dtype of instance method where doesn't match '<U(\d+)' or 'S(\d+)' when one of the fields is a fixed string field"


result_mem_field = FixedStringMemField(a._session, length)
result_mem_field.data.write(r_ndarray)

elif str(r_ndarray.dtype) in utils.PERMITTED_NUMERIC_TYPES:
result_mem_field = NumericMemField(a._session, str(r_ndarray.dtype))
result_mem_field.data.write(r_ndarray)
else:
raise NotImplementedError(f"instance method `where` doesn't support the current input type: {type(a)} and {type(b)}")

return result_mem_field


class HDF5Field(Field):
def __init__(self, session, group, dataframe, write_enabled=False):
super().__init__()
Expand Down Expand Up @@ -143,6 +231,20 @@ def _ensure_valid(self):
if not self._valid_reference:
raise ValueError("This field no longer refers to a valid underlying field object")

def where(self, cond:Union[list, tuple, np.ndarray, Field, Callable], b, inplace=False):
if isinstance(cond, (list, tuple, np.ndarray)):
cond = cond
elif isinstance(cond, Field):
if cond.indexed:
raise NotImplementedError("Where does not support indexed string fields at present")
cond = cond.data[:]
elif callable(cond):
cond = cond(self.data[:])
else:
raise TypeError("'cond' parameter needs to be either callable lambda function, or array like, or NumericMemField")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we could just do return where(cond, self, b) and then the rest of the body of this method can be put into the global where function.

return where_helper(cond, self, b)


class MemoryField(Field):

Expand Down Expand Up @@ -223,6 +325,21 @@ def apply_index(self, index_to_apply, dstfld=None):
raise NotImplementedError("Please use apply_index() on specific fields, not the field base class.")


def where(self, cond:Union[list, tuple, np.ndarray, Field, Callable], b, inplace=False):
if isinstance(cond, (list, tuple, np.ndarray)):
cond = cond
elif isinstance(cond, Field):
if cond.indexed:
raise NotImplementedError("Where does not support indexed string fields at present")
cond = cond.data[:]
elif callable(cond):
cond = cond(self.data[:])
else:
raise TypeError("'cond' parameter needs to be either callable lambda function, or array like, or NumericMemField")

return where_helper(cond, self, b)


class ReadOnlyFieldArray:
def __init__(self, field, dataset_name):
self._field = field
Expand Down
10 changes: 10 additions & 0 deletions exetera/core/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3111,3 +3111,13 @@ def compare_arrays(a, b):
return 1
return 0


@exetera_njit
def where_for_two_indexed_string_fields(cond, a_indices, a_values, b_indices, b_values, r_indices, r_values):
for i, c in enumerate(cond):
if c:
r_indices[i + 1] = r_indices[i] + a_indices[i + 1] - a_indices[i]
r_values[r_indices[i]:r_indices[i + 1]] = a_values[a_indices[i]:a_indices[i + 1]]
else:
r_indices[i + 1] = r_indices[i] + b_indices[i + 1] - b_indices[i]
r_values[r_indices[i]:r_indices[i + 1]] = b_values[b_indices[i]:b_indices[i + 1]]
161 changes: 161 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2169,6 +2169,167 @@ def test_indexed_string_isin(self, data, isin_data, expected):
np.testing.assert_array_equal(expected, result)


WHERE_BOOLEAN_COND = RAND_STATE.randint(0, 2, 20).tolist()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we missing tests for when cond is a field?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, currently unittest for cond is a field is missing. I'm trying to add one.
So for the indexedstringfield, we will throw out the exception.
How should we deal with the FixedStringField? As we can't use string as boolean value directly, so which case should be considered True for fixedstringfield, and which case is False?

WHERE_NUMERIC_FIELD_DATA = shuffle_randstate(list(range(-10,10)))
WHERE_FIXED_STRING_FIELD_DATA = [b"aaa", b"bbb", b"eee", b"ccc", b" "]*4
WHERE_CATEGORICAL_FIELD_DATA = RAND_STATE.randint(1, 4, 20).tolist()
WHERE_INDEXED_STRING_FIELD_DATA = (["a", "bb", "eeeee", "ccc", "dddd","", " "]*3)[:-1:] # make data length to 20

WHERE_NUMERIC_TESTS = [
(lambda f: f > 5, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA, None, None, 0, 'int8'),
(lambda f: f > 5, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, None, None, -1.0, 'float64'),
(lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, None, None, -1.0, 'float32'),
(lambda f: f > 5, "create_numeric", {"nformat": "int32"}, WHERE_NUMERIC_FIELD_DATA, None, None, shuffle_randstate(list(range(0,20))), 'int64'),
(lambda f: f > 5, "create_numeric", {"nformat": "int32"}, WHERE_NUMERIC_FIELD_DATA, None, None, np.array(shuffle_randstate(list(range(0,20))), dtype='int32'), 'int32'),
(lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, None, None, np.array(shuffle_randstate(list(range(-20,0))), dtype='float32'), 'float32'),
(lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, "create_categorical", {"nformat": "int8", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'float32'),
(lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'float64'),
(lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, 'float32'),
(lambda f: f > 5, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA, 'float32'),
(lambda f: f > 5, "create_numeric", {"nformat": "float32"}, WHERE_NUMERIC_FIELD_DATA,"create_numeric", {"nformat": "float64"}, WHERE_NUMERIC_FIELD_DATA, 'float64'),
(WHERE_BOOLEAN_COND, "create_categorical", {"nformat": "int16", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA,"create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, 'int32'),

]

WHERE_FIXED_STRING_TESTS = [
(lambda f: f > 5, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA),
(lambda f: f > 2, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA),
(WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2300: can we also do this for float32?

(WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA),
(WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA, "create_fixed_string", {"length": 1}, [b"a", b"b", b"e", b"c", b" "]*4 ),
]

WHERE_INDEXED_STRING_TESTS = [
(WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA, "create_indexed_string", {}, shuffle_randstate(WHERE_INDEXED_STRING_FIELD_DATA)),
(WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA),
(WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA),
(WHERE_BOOLEAN_COND, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA),
(WHERE_BOOLEAN_COND, "create_fixed_string", {"length": 3}, WHERE_FIXED_STRING_FIELD_DATA, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA),
(WHERE_BOOLEAN_COND, "create_numeric", {"nformat": "int8"}, WHERE_NUMERIC_FIELD_DATA, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA),
(WHERE_BOOLEAN_COND, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, WHERE_CATEGORICAL_FIELD_DATA, "create_indexed_string", {}, WHERE_INDEXED_STRING_FIELD_DATA),
]

def where_oracle(cond, a, b):
if callable(cond):
if isinstance(a, fields.Field):
cond = cond(a.data[:])
elif isinstance(a, list):
cond = cond(np.array(a))
elif isinstance(a, np.ndarray):
cond = cond(a)
return np.where(cond, a, b)


class TestFieldWhereFunctions(SessionTestCase):

@parameterized.expand(WHERE_NUMERIC_TESTS)
def test_module_fields_where(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_data, expected_dtype):
"""
Test `where` for the numeric fields using `fields.where` function and the object's method.
"""
a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data)
expected_result = where_oracle(cond, a_field_data, b_data)

if b_kwarg is None:
with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_data)}"):
if callable(cond):
with self.assertRaises(NotImplementedError) as context:
result = fields.where(cond, a_field, b_data)
self.assertEqual(str(context.exception), "module method `fields.where` doesn't support callable cond, please use instance mehthod `where` for callable cond.")
else:
result = fields.where(cond, a_field, b_data)
np.testing.assert_array_equal(expected_result, result)

else:
b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_data)
with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"):
if callable(cond):
with self.assertRaises(NotImplementedError) as context:
result = fields.where(cond, a_field, b_field)
self.assertEqual(str(context.exception), "module method `fields.where` doesn't support callable cond, please use instance mehthod `where` for callable cond.")
else:
result = fields.where(cond, a_field, b_field)
np.testing.assert_array_equal(expected_result, result)


@parameterized.expand(WHERE_NUMERIC_TESTS)
def test_instance_field_where_return_numeric_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_data, expected_dtype):
a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data)

expected_result = where_oracle(cond, a_field_data, b_data)

if b_kwarg is None:
with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_data)}"):
result = a_field.where(cond, b_data)
self.assertEqual(result._nformat, expected_dtype)
np.testing.assert_array_equal(result, expected_result)

else:
b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_data)

with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"):
result = a_field.where(cond, b_field)
self.assertIsInstance(result, fields.NumericMemField)
self.assertEqual(result._nformat, expected_dtype)
np.testing.assert_array_equal(result, expected_result)


@parameterized.expand(WHERE_FIXED_STRING_TESTS)
def test_instance_field_where_return_fixed_string_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_field_data):
a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data)
b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_field_data)

expected_result = where_oracle(cond, a_field_data, b_field_data)

with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this to after the mem fields are created

result = a_field.where(cond, b_field)
self.assertIsInstance(result, fields.FixedStringMemField)
np.testing.assert_array_equal(result.data[:], expected_result)

# reload to test FixedStringMemField
a_mem_field, b_mem_field = a_field, b_field
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this to before the first subtest

if isinstance(a_field, fields.FixedStringField):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

condition can be removed

a_mem_field = fields.FixedStringMemField(self.s, a_kwarg["length"])
a_mem_field.data.write(np.array(a_field_data))

if isinstance(b_field, fields.FixedStringField):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

condition can be removed

b_mem_field = fields.FixedStringMemField(self.s, b_kwarg["length"])
b_mem_field.data.write(np.array(b_field_data))

with self.subTest(f"Test instance where method: a is {type(a_mem_field)}, b is {type(b_mem_field)}"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do all four combinations:
a_field, b_field
a_field, b_mem_field
a_mem_field, b_field
a_mem_field, b_mem_field

result = a_mem_field.where(cond, b_mem_field)
self.assertIsInstance(result, fields.FixedStringMemField)
np.testing.assert_array_equal(result.data[:], expected_result)


@parameterized.expand(WHERE_INDEXED_STRING_TESTS)
def test_instance_field_where_return_indexed_string_mem_field(self, cond, a_creator, a_kwarg, a_field_data, b_creator, b_kwarg, b_field_data):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here with combinations of hdf5 and mem fields

a_field = self.setup_field(self.df, a_creator, "af", (), a_kwarg, a_field_data)
b_field = self.setup_field(self.df, b_creator, "bf", (), b_kwarg, b_field_data)

expected_result = where_oracle(cond, a_field_data, b_field_data)

with self.subTest(f"Test instance where method: a is {type(a_field)}, b is {type(b_field)}"):
result = a_field.where(cond, b_field)
self.assertIsInstance(result, fields.IndexedStringMemField)
np.testing.assert_array_equal(result.data[:], expected_result)

# reload to test FixedStringMemField
a_mem_field, b_mem_field = a_field, b_field
if isinstance(a_field, fields.IndexedStringMemField):
a_mem_field = fields.IndexedStringMemField(self.s)
a_mem_field.data.write(a_field_data)

if isinstance(b_field, fields.IndexedStringMemField):
b_mem_field = fields.IndexedStringMemField(self.s)
b_mem_field.data.write(b_field_data)

with self.subTest(f"Test instance where method: a is {type(a_mem_field)}, b is {type(b_mem_field)}"):
result = a_mem_field.where(cond, b_mem_field)
self.assertIsInstance(result, fields.IndexedStringMemField)
np.testing.assert_array_equal(result.data[:], expected_result)


class TestFieldModuleFunctions(SessionTestCase):

@parameterized.expand(DEFAULT_FIELD_DATA)
Expand Down