diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 62e62e91..d9926d0f 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -654,7 +654,7 @@ class SimpleArray template using is_simple_array = std::is_same< - std::remove_reference_t, + std::remove_const_t>, SimpleArray::value_type>>; template diff --git a/cpp/modmesh/buffer/pymod/SimpleArrayCaster.hpp b/cpp/modmesh/buffer/pymod/SimpleArrayCaster.hpp index 9bb105af..587782a9 100644 --- a/cpp/modmesh/buffer/pymod/SimpleArrayCaster.hpp +++ b/cpp/modmesh/buffer/pymod/SimpleArrayCaster.hpp @@ -77,12 +77,6 @@ namespace detail const modmesh::SimpleArray##DATATYPE * array_from_arrayplex = reinterpret_cast(arrayplex.instance_ptr()); \ value = const_cast(array_from_arrayplex); \ return true; \ - } \ - \ - /* Conversion from C++ to Python object */ \ - static pybind11::handle cast(modmesh::SimpleArray##DATATYPE && src, pybind11::return_value_policy policy, pybind11::handle parent) \ - { \ - return base::cast(src, policy, parent); \ } \ } diff --git a/cpp/modmesh/buffer/pymod/TypeBroadcast.hpp b/cpp/modmesh/buffer/pymod/TypeBroadcast.hpp index 601abbfb..65339705 100644 --- a/cpp/modmesh/buffer/pymod/TypeBroadcast.hpp +++ b/cpp/modmesh/buffer/pymod/TypeBroadcast.hpp @@ -151,6 +151,34 @@ struct TypeBroadcast } } + static void broadcast(SimpleArray & arr_out, std::vector const & slices, pybind11::object const & py_number) + { + const T assigned_value = py_number.cast(); + + shape_type out_shape(arr_out.ndim()); + for (size_t i = 0; i < arr_out.ndim(); i++) + { + slice_type const & slice = slices[i]; + if ((slice[1] - slice[0]) % slice[2] == 0) + { + out_shape[i] = (slice[1] - slice[0]) / slice[2]; + } + else + { + out_shape[i] = (slice[1] - slice[0]) / slice[2] + 1; + } + } + + shape_type sidx_init(arr_out.ndim()); + + for (size_t i = 0; i < arr_out.ndim(); ++i) + { + sidx_init[i] = 0; + } + + assigned_idx(arr_out, slices, assigned_value, out_shape, sidx_init, static_cast(arr_out.ndim()) - 1); + } + static void broadcast(SimpleArray & arr_out, std::vector const & slices, pybind11::array const & arr_in) { if (dtype_is_type(arr_in)) @@ -230,6 +258,34 @@ struct TypeBroadcast throw std::runtime_error(msg.str()); } + +private: + // NOLINTNEXTLINE(misc-no-recursion) + static void assigned_idx(SimpleArray & arr_out, std::vector const & slices, const T value, shape_type out_shape, shape_type sidx, int dim) + { + if (dim < 0) + { + return; + } + + for (size_t i = 0; i < out_shape[dim]; ++i) + { + sidx[dim] = i; + + size_t offset_out = 0; + for (size_t it = 0; it < arr_out.ndim(); ++it) + { + auto step = slices[it][2]; + offset_out += arr_out.stride(it) * sidx[it] * step; + } + + // NOLINTNEXTLINE(bugprone-signed-char-misuse, cert-str34-c) + arr_out.at(offset_out) = value; + + // recursion here + assigned_idx(arr_out, slices, value, out_shape, sidx, dim - 1); + } + } }; /* end struct TypeBroadCast */ } /* end namespace python */ diff --git a/cpp/modmesh/buffer/pymod/array_common.hpp b/cpp/modmesh/buffer/pymod/array_common.hpp index 0f75f3e7..10240bb9 100644 --- a/cpp/modmesh/buffer/pymod/array_common.hpp +++ b/cpp/modmesh/buffer/pymod/array_common.hpp @@ -89,6 +89,78 @@ class ArrayPropertyHelper } } + static pybind11::object getitem_parser(const SimpleArray & arr, pybind11::args const & args) + { + namespace py = pybind11; + if (args.size() != 1) + { + throw std::runtime_error("unsupported operation."); + } + + const py::object & py_key = args[0]; + + // sarr[x] + if (py::isinstance(py_key)) + { + const auto key = py_key.cast(); + return py::cast(arr.at(key)); + } + + bool is_tuple = py::isinstance(py_key); + bool is_number_tuple = false; + if (is_tuple) + { + const py::tuple tuple_in = py_key; + if (tuple_in.size() > 0) + { + is_number_tuple = py::isinstance(tuple_in[0]); + } + } + + // sarr[x, y, z] + if (is_number_tuple) + { + const auto key = py_key.cast>(); + return py::cast(arr.at(key)); + } + + // multi-dimension with slice and ellipsis + // sarr[slice, slice, ellipsis] + if (is_tuple) + { + const py::tuple tuple_in = py_key; + + auto slices = make_default_slices(arr); + process_slices(tuple_in, slices, arr.ndim()); + + SimpleArray arr_out(get_shape_from_slices(slices)); + + broadcast_array_using_slice(arr_out, slices, to_ndarray(arr)); + return py::cast(arr_out); + } + // one-dimension with slice + // sarr[slice] + if (py::isinstance(py_key)) + { + const auto slice_in = py_key.cast(); + + auto slices = make_default_slices(arr); + copy_slice(slices[0], slice_in); + + SimpleArray arr_out(get_shape_from_slices(slices)); + + broadcast_array_using_slice(arr_out, slices, to_ndarray(arr)); + return py::cast(arr_out); + } + // sarr[ellipsis] + if (py::isinstance(py_key)) + { + return py::cast(arr); + } + + throw std::runtime_error("unsupported operation."); + } + static void setitem_parser(SimpleArray & arr_out, pybind11::args const & args) { namespace py = pybind11; @@ -116,6 +188,40 @@ class ArrayPropertyHelper arr_out.at(key) = py_value.cast(); return; } + // multi-dimension with slice and ellipsis + // sarr[slice, slice, ellipsis] = v + if (py::isinstance(py_key) && is_number) + { + const py::tuple tuple_in = py_key; + + auto slices = make_default_slices(arr_out); + process_slices(tuple_in, slices, arr_out.ndim()); + + broadcast_array_using_slice(arr_out, slices, py_value); + return; + } + // one-dimension with slice + // sarr[slice] = v + if (py::isinstance(py_key) && is_number) + { + const auto slice_in = py_key.cast(); + + auto slices = make_default_slices(arr_out); + copy_slice(slices[0], slice_in); + + broadcast_array_using_slice(arr_out, slices, py_value); + return; + } + // sarr[ellipsis] = V + if (py::isinstance(py_key) && is_number) + { + const auto value = py_value.cast(); + for (ssize_t i = 0; i < arr_out.size(); i++) + { + arr_out.at(i) = value; + } + return; + } const bool is_sequence = py::isinstance(py_value) || py::isinstance(py_value) || py::isinstance(py_value); @@ -297,6 +403,47 @@ class ArrayPropertyHelper arr_out.set_nghost(nghost); } } + + static void broadcast_array_using_slice(SimpleArray & arr_out, + std::vector const & slices, + const pybind11::object & py_number) + { + namespace py = pybind11; + + if (!py::isinstance(py_number) && !py::isinstance(py_number) && !py::isinstance(py_number)) + { + throw std::runtime_error("Cannot broadcast a non-number value to an array."); + } + + const size_t nghost = arr_out.nghost(); + if (0 != nghost) + { + arr_out.set_nghost(0); + } + + TypeBroadcast::broadcast(arr_out, slices, py_number); + + if (0 != nghost) + { + arr_out.set_nghost(nghost); + } + } + + static shape_type get_shape_from_slices(std::vector const & slices) + { + + shape_type shape; + for (auto const & slice : slices) + { + std::cout << slice[0] << ", " << slice[1] << ", " << slice[2] << std::endl; + + shape.push_back((slice[1] - slice[0]) / slice[2]); + + std::cout << ((slice[1] - slice[0]) / slice[2]) << std::endl; + } + std::cout << std::endl; + return shape; + } }; } /* end namespace python */ diff --git a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp index f5b95fa1..94c59cc9 100644 --- a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp +++ b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp @@ -123,15 +123,8 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray return ret; }) .def("__len__", &wrapped_type::size) - .def( - "__getitem__", - [](wrapped_type const & self, ssize_t key) - { return self.at(key); }) - .def( - "__getitem__", - [](wrapped_type const & self, std::vector const & key) - { return self.at(key); }) .def("__setitem__", &property_helper::setitem_parser) + .def("__getitem__", &property_helper::getitem_parser) .def( "reshape", [](wrapped_type const & self, py::object const & shape) diff --git a/tests/test_buffer.py b/tests/test_buffer.py index 92e7cca4..47a3b881 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -434,6 +434,111 @@ def test_SimpleArray_from_ndarray_content(self): sarr.ndarray.fill(100) self.assertTrue((ndarr == 100).all()) + def test_SimpleArray_getitem_ellipsis(self): + sarr = modmesh.SimpleArrayFloat64((2, 3, 4)) + + count = 0 + for i in range(2): + for j in range(3): + for k in range(4): + sarr[i, j, k] = count + count += 1 + + sarr2 = sarr[...] + self.assertEqual(sarr.shape, sarr2.shape) + for i in range(2): + for j in range(3): + for k in range(4): + self.assertEqual(sarr[i, j, k], sarr2[i, j, k]) + + def test_SimpleArray_getitem_slice_ellipsis(self): + sarr = modmesh.SimpleArrayFloat64((2, 3, 4)) + ndarr = np.arange(2 * 3 * 4, dtype='float64').reshape((2, 3, 4)) + + count = 0 + for i in range(2): + for j in range(3): + for k in range(4): + sarr[i, j, k] = count + ndarr[i, j, k] = count + count += 1 + + sarr2 = sarr[::, ..., ::2] + ndarr2 = ndarr[::, ..., ::2] + self.assertEqual(ndarr2.shape, sarr2.shape) + for i in range(2): + for j in range(3): + for k in range(2): + self.assertEqual(ndarr2[i, j, k], sarr2[i, j, k]) + + def test_SimpleArray_getitem_slice(self): + sarr = modmesh.SimpleArrayFloat64(24) + ndarr = np.arange(2 * 3 * 4, dtype='float64') + + count = 0 + for i in range(24): + sarr[i] = count + ndarr[i] = count + count += 1 + + sarr2 = sarr[1:19:3] + ndarr2 = ndarr[1:19:3] + self.assertEqual(ndarr2.size, sarr2.size) + for i in range(ndarr2.size): + self.assertEqual(ndarr2[i], sarr2[i]) + + def test_SimpleArray_broadcast_ellipsis_number(self): + sarr = modmesh.SimpleArrayFloat64((2, 3, 4)) + sarr[...] = 1.234 + for i in range(2): + for j in range(3): + for k in range(4): + self.assertEqual(1.234, sarr[i, j, k]) + + def test_SimpleArray_broadcast_slice_ellipsis_number(self): + sarr = modmesh.SimpleArrayFloat64((2, 3, 4)) + ndarr = np.arange(2 * 3 * 4, dtype='float64').reshape((2, 3, 4)) + + VALUE1 = 1.234 + VALUE2 = 5.678 + sarr.fill(VALUE1) + ndarr[...] = VALUE1 + + ndarr[0:1, ..., ::2] = VALUE2 + sarr[0:1, ..., ::2] = VALUE2 + + for i in range(2): + for j in range(3): + for k in range(4): + self.assertEqual(ndarr[i, j, k], sarr[i, j, k]) + + def test_SimpleArray_broadcast_slice_number(self): + TOTAL = 20 + + sarr = modmesh.SimpleArrayFloat64(TOTAL) + ndarr = np.arange(100, dtype='float64') + + VALUE1 = 1.234 + VALUE2 = 5.678 + VALUE3 = 9.123 + + sarr.fill(VALUE1) + ndarr[...] = VALUE1 + + ndarr[2:15:3] = VALUE2 + sarr[2:15:3] = VALUE2 + + for i in range(TOTAL): + print(i, ndarr[i], sarr[i]) + for i in range(TOTAL): + self.assertEqual(ndarr[i], sarr[i]) + + ndarr[5:19:2] = VALUE3 + sarr[5:19:2] = VALUE3 + + for i in range(TOTAL): + self.assertEqual(ndarr[i], sarr[i]) + def test_SimpleArray_broadcast_ellipsis_shape(self): sarr = modmesh.SimpleArrayFloat64((2, 3, 4)) ndarr = np.arange(2 * 3 * 4, dtype='float64').reshape((2, 3, 4)) @@ -641,8 +746,6 @@ def test_SimpleArray_broadcast_slice_ghost_1d(self): sarr[::STEP] = ndarr[...] ndarr2[::STEP] = ndarr[...] - print(ndarr2.shape) - for i in range(0, N, STEP): self.assertEqual(ndarr2[i], sarr[i - G]) diff --git a/thirdparty/PUI b/thirdparty/PUI index 1560f360..ff24318c 160000 --- a/thirdparty/PUI +++ b/thirdparty/PUI @@ -1 +1 @@ -Subproject commit 1560f360f4d5fd62f2204b7e52297338249df350 +Subproject commit ff24318c398d62c9033a75d608b962cd8173afb5