Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement array wrapper getitem #327

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/modmesh/buffer/SimpleArray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ class SimpleArray

template <typename S>
using is_simple_array = std::is_same<
std::remove_reference_t<S>,
std::remove_const_t<std::remove_reference_t<S>>,
SimpleArray<typename std::remove_reference_t<S>::value_type>>;

template <typename S>
Expand Down
6 changes: 0 additions & 6 deletions cpp/modmesh/buffer/pymod/SimpleArrayCaster.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,6 @@ namespace detail
const modmesh::SimpleArray##DATATYPE * array_from_arrayplex = reinterpret_cast<const modmesh::SimpleArray##DATATYPE *>(arrayplex.instance_ptr()); \
value = const_cast<modmesh::SimpleArray##DATATYPE *>(array_from_arrayplex); \
return true; \
} \
\
/* Conversion from C++ to Python object */ \
static pybind11::handle cast(modmesh::SimpleArray##DATATYPE && src, pybind11::return_value_policy policy, pybind11::handle parent) \
{ \
return base::cast(src, policy, parent); \
} \
}

Expand Down
56 changes: 56 additions & 0 deletions cpp/modmesh/buffer/pymod/TypeBroadcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,34 @@ struct TypeBroadcast
}
}

static void broadcast(SimpleArray<T> & arr_out, std::vector<slice_type> const & slices, pybind11::object const & py_number)
{
const T assigned_value = py_number.cast<T>();

shape_type out_shape(arr_out.ndim());
for (size_t i = 0; i < arr_out.ndim(); i++)
{
slice_type const & slice = slices[i];
if ((slice[1] - slice[0]) % slice[2] == 0)
{
out_shape[i] = (slice[1] - slice[0]) / slice[2];
}
else
{
out_shape[i] = (slice[1] - slice[0]) / slice[2] + 1;
}
}

shape_type sidx_init(arr_out.ndim());

for (size_t i = 0; i < arr_out.ndim(); ++i)
{
sidx_init[i] = 0;
}

assigned_idx(arr_out, slices, assigned_value, out_shape, sidx_init, static_cast<int>(arr_out.ndim()) - 1);
}

static void broadcast(SimpleArray<T> & arr_out, std::vector<slice_type> const & slices, pybind11::array const & arr_in)
{
if (dtype_is_type<bool>(arr_in))
Expand Down Expand Up @@ -230,6 +258,34 @@ struct TypeBroadcast

throw std::runtime_error(msg.str());
}

private:
// NOLINTNEXTLINE(misc-no-recursion)
static void assigned_idx(SimpleArray<T> & arr_out, std::vector<slice_type> const & slices, const T value, shape_type out_shape, shape_type sidx, int dim)
{
if (dim < 0)
{
return;
}

for (size_t i = 0; i < out_shape[dim]; ++i)
{
sidx[dim] = i;

size_t offset_out = 0;
for (size_t it = 0; it < arr_out.ndim(); ++it)
{
auto step = slices[it][2];
offset_out += arr_out.stride(it) * sidx[it] * step;
}

// NOLINTNEXTLINE(bugprone-signed-char-misuse, cert-str34-c)
arr_out.at(offset_out) = value;

// recursion here
assigned_idx(arr_out, slices, value, out_shape, sidx, dim - 1);
}
}
}; /* end struct TypeBroadCast */

} /* end namespace python */
Expand Down
147 changes: 147 additions & 0 deletions cpp/modmesh/buffer/pymod/array_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,78 @@ class ArrayPropertyHelper
}
}

static pybind11::object getitem_parser(const SimpleArray<T> & arr, pybind11::args const & args)
{
namespace py = pybind11;
if (args.size() != 1)
{
throw std::runtime_error("unsupported operation.");
}

const py::object & py_key = args[0];

// sarr[x]
if (py::isinstance<py::int_>(py_key))
{
const auto key = py_key.cast<ssize_t>();
return py::cast(arr.at(key));
}

bool is_tuple = py::isinstance<py::tuple>(py_key);
bool is_number_tuple = false;
if (is_tuple)
{
const py::tuple tuple_in = py_key;
if (tuple_in.size() > 0)
{
is_number_tuple = py::isinstance<py::int_>(tuple_in[0]);
}
}

// sarr[x, y, z]
if (is_number_tuple)
{
const auto key = py_key.cast<std::vector<ssize_t>>();
return py::cast(arr.at(key));
}

// multi-dimension with slice and ellipsis
// sarr[slice, slice, ellipsis]
if (is_tuple)
{
const py::tuple tuple_in = py_key;

auto slices = make_default_slices(arr);
process_slices(tuple_in, slices, arr.ndim());

SimpleArray<T> arr_out(get_shape_from_slices(slices));

broadcast_array_using_slice(arr_out, slices, to_ndarray(arr));
return py::cast(arr_out);
}
// one-dimension with slice
// sarr[slice]
if (py::isinstance<py::slice>(py_key))
{
const auto slice_in = py_key.cast<py::slice>();

auto slices = make_default_slices(arr);
copy_slice(slices[0], slice_in);

SimpleArray<T> arr_out(get_shape_from_slices(slices));

broadcast_array_using_slice(arr_out, slices, to_ndarray(arr));
return py::cast(arr_out);
}
// sarr[ellipsis]
if (py::isinstance<py::ellipsis>(py_key))
{
return py::cast(arr);
}

throw std::runtime_error("unsupported operation.");
}

static void setitem_parser(SimpleArray<T> & arr_out, pybind11::args const & args)
{
namespace py = pybind11;
Expand Down Expand Up @@ -116,6 +188,40 @@ class ArrayPropertyHelper
arr_out.at(key) = py_value.cast<T>();
return;
}
// multi-dimension with slice and ellipsis
// sarr[slice, slice, ellipsis] = v
if (py::isinstance<py::tuple>(py_key) && is_number)
{
const py::tuple tuple_in = py_key;

auto slices = make_default_slices(arr_out);
process_slices(tuple_in, slices, arr_out.ndim());

broadcast_array_using_slice(arr_out, slices, py_value);
return;
}
// one-dimension with slice
// sarr[slice] = v
if (py::isinstance<py::slice>(py_key) && is_number)
{
const auto slice_in = py_key.cast<py::slice>();

auto slices = make_default_slices(arr_out);
copy_slice(slices[0], slice_in);

broadcast_array_using_slice(arr_out, slices, py_value);
return;
}
// sarr[ellipsis] = V
if (py::isinstance<py::ellipsis>(py_key) && is_number)
{
const auto value = py_value.cast<T>();
for (ssize_t i = 0; i < arr_out.size(); i++)
{
arr_out.at(i) = value;
}
return;
}

const bool is_sequence = py::isinstance<py::list>(py_value) || py::isinstance<py::array>(py_value) || py::isinstance<py::tuple>(py_value);

Expand Down Expand Up @@ -297,6 +403,47 @@ class ArrayPropertyHelper
arr_out.set_nghost(nghost);
}
}

static void broadcast_array_using_slice(SimpleArray<T> & arr_out,
std::vector<slice_type> const & slices,
const pybind11::object & py_number)
{
namespace py = pybind11;

if (!py::isinstance<py::bool_>(py_number) && !py::isinstance<py::int_>(py_number) && !py::isinstance<py::float_>(py_number))
{
throw std::runtime_error("Cannot broadcast a non-number value to an array.");
}

const size_t nghost = arr_out.nghost();
if (0 != nghost)
{
arr_out.set_nghost(0);
}

TypeBroadcast<T>::broadcast(arr_out, slices, py_number);

if (0 != nghost)
{
arr_out.set_nghost(nghost);
}
}

static shape_type get_shape_from_slices(std::vector<slice_type> const & slices)
{

shape_type shape;
for (auto const & slice : slices)
{
std::cout << slice[0] << ", " << slice[1] << ", " << slice[2] << std::endl;

shape.push_back((slice[1] - slice[0]) / slice[2]);

std::cout << ((slice[1] - slice[0]) / slice[2]) << std::endl;
}
std::cout << std::endl;
return shape;
}
};

} /* end namespace python */
Expand Down
9 changes: 1 addition & 8 deletions cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,8 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray
return ret;
})
.def("__len__", &wrapped_type::size)
.def(
"__getitem__",
[](wrapped_type const & self, ssize_t key)
{ return self.at(key); })
.def(
"__getitem__",
[](wrapped_type const & self, std::vector<ssize_t> const & key)
{ return self.at(key); })
.def("__setitem__", &property_helper::setitem_parser)
.def("__getitem__", &property_helper::getitem_parser)
.def(
"reshape",
[](wrapped_type const & self, py::object const & shape)
Expand Down
107 changes: 105 additions & 2 deletions tests/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,111 @@ def test_SimpleArray_from_ndarray_content(self):
sarr.ndarray.fill(100)
self.assertTrue((ndarr == 100).all())

def test_SimpleArray_getitem_ellipsis(self):
sarr = modmesh.SimpleArrayFloat64((2, 3, 4))

count = 0
for i in range(2):
for j in range(3):
for k in range(4):
sarr[i, j, k] = count
count += 1

sarr2 = sarr[...]
self.assertEqual(sarr.shape, sarr2.shape)
for i in range(2):
for j in range(3):
for k in range(4):
self.assertEqual(sarr[i, j, k], sarr2[i, j, k])

def test_SimpleArray_getitem_slice_ellipsis(self):
sarr = modmesh.SimpleArrayFloat64((2, 3, 4))
ndarr = np.arange(2 * 3 * 4, dtype='float64').reshape((2, 3, 4))

count = 0
for i in range(2):
for j in range(3):
for k in range(4):
sarr[i, j, k] = count
ndarr[i, j, k] = count
count += 1

sarr2 = sarr[::, ..., ::2]
ndarr2 = ndarr[::, ..., ::2]
self.assertEqual(ndarr2.shape, sarr2.shape)
for i in range(2):
for j in range(3):
for k in range(2):
self.assertEqual(ndarr2[i, j, k], sarr2[i, j, k])

def test_SimpleArray_getitem_slice(self):
sarr = modmesh.SimpleArrayFloat64(24)
ndarr = np.arange(2 * 3 * 4, dtype='float64')

count = 0
for i in range(24):
sarr[i] = count
ndarr[i] = count
count += 1

sarr2 = sarr[1:19:3]
ndarr2 = ndarr[1:19:3]
self.assertEqual(ndarr2.size, sarr2.size)
for i in range(ndarr2.size):
self.assertEqual(ndarr2[i], sarr2[i])

def test_SimpleArray_broadcast_ellipsis_number(self):
sarr = modmesh.SimpleArrayFloat64((2, 3, 4))
sarr[...] = 1.234
for i in range(2):
for j in range(3):
for k in range(4):
self.assertEqual(1.234, sarr[i, j, k])

def test_SimpleArray_broadcast_slice_ellipsis_number(self):
sarr = modmesh.SimpleArrayFloat64((2, 3, 4))
ndarr = np.arange(2 * 3 * 4, dtype='float64').reshape((2, 3, 4))

VALUE1 = 1.234
VALUE2 = 5.678
sarr.fill(VALUE1)
ndarr[...] = VALUE1

ndarr[0:1, ..., ::2] = VALUE2
sarr[0:1, ..., ::2] = VALUE2

for i in range(2):
for j in range(3):
for k in range(4):
self.assertEqual(ndarr[i, j, k], sarr[i, j, k])

def test_SimpleArray_broadcast_slice_number(self):
TOTAL = 20

sarr = modmesh.SimpleArrayFloat64(TOTAL)
ndarr = np.arange(100, dtype='float64')

VALUE1 = 1.234
VALUE2 = 5.678
VALUE3 = 9.123

sarr.fill(VALUE1)
ndarr[...] = VALUE1

ndarr[2:15:3] = VALUE2
sarr[2:15:3] = VALUE2

for i in range(TOTAL):
print(i, ndarr[i], sarr[i])
for i in range(TOTAL):
self.assertEqual(ndarr[i], sarr[i])

ndarr[5:19:2] = VALUE3
sarr[5:19:2] = VALUE3

for i in range(TOTAL):
self.assertEqual(ndarr[i], sarr[i])

def test_SimpleArray_broadcast_ellipsis_shape(self):
sarr = modmesh.SimpleArrayFloat64((2, 3, 4))
ndarr = np.arange(2 * 3 * 4, dtype='float64').reshape((2, 3, 4))
Expand Down Expand Up @@ -641,8 +746,6 @@ def test_SimpleArray_broadcast_slice_ghost_1d(self):
sarr[::STEP] = ndarr[...]
ndarr2[::STEP] = ndarr[...]

print(ndarr2.shape)

for i in range(0, N, STEP):
self.assertEqual(ndarr2[i], sarr[i - G])

Expand Down
Loading
Loading