Skip to content

Commit

Permalink
Add complex_as_vector and vector_as_complex functions, reinterpreting…
Browse files Browse the repository at this point in the history
… complex as 2-D float vectors and vice versa.
  • Loading branch information
Ivorforce committed Nov 10, 2024
1 parent e3024b1 commit 94049f7
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 34 deletions.
19 changes: 19 additions & 0 deletions doc_classes/nd.xml
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,14 @@
No check is performed to ensure a_min < a_max.
</description>
</method>
<method name="complex_as_vector" qualifiers="static">
<return type="NDArray" />
<param index="0" name="v" type="Variant" />
<description>
Reinterprets elements in a complex-typed array as a 2-D vector.
The new dimension is added in the back.
</description>
</method>
<method name="concatenate" qualifiers="static">
<return type="NDArray" />
<param index="0" name="v" type="Variant" />
Expand Down Expand Up @@ -1061,6 +1069,17 @@
Returns a 0-dimension scalar if axes is null. In that case, consider [method ndf.var].
</description>
</method>
<method name="vector_as_complex" qualifiers="static">
<return type="NDArray" />
<param index="0" name="v" type="Variant" />
<param index="1" name="keepdims" type="bool" default="false" />
<param index="2" name="dtype" type="int" enum="nd.DType" default="13" />
<description>
Reinterprets a real valued array with [code]shape[-1] == 2[/code] as a complex valued array.
If [param keepdims] is false (default), the last dimension will be consumed. Otherwise, it will be 1.
This function will return a view if possible, but create a new array otherwise.
</description>
</method>
<method name="vsplit" qualifiers="static">
<return type="NDArray[]" />
<param index="0" name="v" type="Variant" />
Expand Down
34 changes: 34 additions & 0 deletions docs/classes/class_nd.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ Methods
+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| :ref:`NDArray<class_NDArray>` | :ref:`clip<class_nd_method_clip>`\ (\ a\: ``Variant``, min\: ``Variant``, max\: ``Variant``\ ) |static| |
+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| :ref:`NDArray<class_NDArray>` | :ref:`complex_as_vector<class_nd_method_complex_as_vector>`\ (\ v\: ``Variant``\ ) |static| |
+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| :ref:`NDArray<class_NDArray>` | :ref:`concatenate<class_nd_method_concatenate>`\ (\ v\: ``Variant``, axis\: ``int`` = 0, dtype\: :ref:`DType<enum_nd_DType>` = 13\ ) |static| |
+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| :ref:`NDArray<class_NDArray>` | :ref:`conjugate<class_nd_method_conjugate>`\ (\ v\: ``Variant``\ ) |static| |
Expand Down Expand Up @@ -274,6 +276,8 @@ Methods
+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| :ref:`NDArray<class_NDArray>` | :ref:`var<class_nd_method_var>`\ (\ a\: ``Variant``, axes\: ``Variant`` = null\ ) |static| |
+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| :ref:`NDArray<class_NDArray>` | :ref:`vector_as_complex<class_nd_method_vector_as_complex>`\ (\ v\: ``Variant``, keepdims\: ``bool`` = false, dtype\: :ref:`DType<enum_nd_DType>` = 13\ ) |static| |
+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| :ref:`Array<class_Array>`\[:ref:`NDArray<class_NDArray>`\] | :ref:`vsplit<class_nd_method_vsplit>`\ (\ v\: ``Variant``, indices_or_section_size\: ``Variant``\ ) |static| |
+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| :ref:`NDArray<class_NDArray>` | :ref:`vstack<class_nd_method_vstack>`\ (\ v\: ``Variant``, dtype\: :ref:`DType<enum_nd_DType>` = 13\ ) |static| |
Expand Down Expand Up @@ -775,6 +779,20 @@ No check is performed to ensure a_min < a_max.

----

.. _class_nd_method_complex_as_vector:

.. rst-class:: classref-method

:ref:`NDArray<class_NDArray>` **complex_as_vector**\ (\ v\: ``Variant``\ ) |static| :ref:`🔗<class_nd_method_complex_as_vector>`

Reinterprets elements in a complex-typed array as a 2-D vector.

The new dimension is added in the back.

.. rst-class:: classref-item-separator

----

.. _class_nd_method_concatenate:

.. rst-class:: classref-method
Expand Down Expand Up @@ -2133,6 +2151,22 @@ Returns a 0-dimension scalar if axes is null. In that case, consider :ref:`ndf.v

----

.. _class_nd_method_vector_as_complex:

.. rst-class:: classref-method

:ref:`NDArray<class_NDArray>` **vector_as_complex**\ (\ v\: ``Variant``, keepdims\: ``bool`` = false, dtype\: :ref:`DType<enum_nd_DType>` = 13\ ) |static| :ref:`🔗<class_nd_method_vector_as_complex>`

Reinterprets a real valued array with ``shape[-1] == 2`` as a complex valued array.

If ``keepdims`` is false (default), the last dimension will be consumed. Otherwise, it will be 1.

This function will return a view if possible, but create a new array otherwise.

.. rst-class:: classref-item-separator

----

.. _class_nd_method_vsplit:

.. rst-class:: classref-method
Expand Down
1 change: 1 addition & 0 deletions docs/setup/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Upcoming Changes (main branch)

- Added complex numbers data types (``complex64`` and ``complex128``).
- Added ``real``, ``imag``, ``conjugate`` and ``angle`` functions for complex numbers.
- Added ``complex_as_vector`` and ``vector_as_complex`` functions for convenient complex number creation and manipulation, similar to ``real`` and ``imag``.
- Added ``any`` layout type, which may bring tiny speed improvements.
- Added ``fft`` function.
- Added ``pad`` function.
Expand Down
14 changes: 14 additions & 0 deletions src/nd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ void nd::_bind_methods() {
godot::ClassDB::bind_static_method("nd", D_METHOD("imag", "v"), &nd::imag);
godot::ClassDB::bind_static_method("nd", D_METHOD("conjugate", "v"), &nd::conjugate);
godot::ClassDB::bind_static_method("nd", D_METHOD("angle", "v"), &nd::angle);
godot::ClassDB::bind_static_method("nd", D_METHOD("vector_as_complex", "v", "keepdims", "dtype"), &nd::vector_as_complex, DEFVAL(false), DEFVAL(nd::DType::DTypeMax));
godot::ClassDB::bind_static_method("nd", D_METHOD("complex_as_vector", "v"), &nd::complex_as_vector);

godot::ClassDB::bind_static_method("nd", D_METHOD("positive", "a"), &nd::positive);
godot::ClassDB::bind_static_method("nd", D_METHOD("negative", "a"), &nd::negative);
Expand Down Expand Up @@ -840,6 +842,18 @@ Ref<NDArray> nd::angle(const Variant& a) {
}, a);
}

Ref<NDArray> nd::vector_as_complex(const Variant& a, bool keepdims, DType dtype) {
return map_variants_as_arrays([keepdims, dtype](const std::shared_ptr<va::VArray>& varray) {
return va::vector_as_complex(va::store::default_allocator, *varray, dtype, keepdims);
}, a);
}

Ref<NDArray> nd::complex_as_vector(const Variant& a) {
return map_variants_as_arrays([](const std::shared_ptr<va::VArray>& varray) {
return va::complex_as_vector(varray);
}, a);
}

Ref<NDArray> nd::positive(const Variant& a) {
return VARRAY_MAP1(positive, a);
}
Expand Down
2 changes: 2 additions & 0 deletions src/nd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ class nd : public Object {
static Ref<NDArray> imag(const Variant& a);
static Ref<NDArray> conjugate(const Variant& a);
static Ref<NDArray> angle(const Variant& a);
static Ref<NDArray> vector_as_complex(const Variant& a, bool keepdims, DType dtype = DType::DTypeMax);
static Ref<NDArray> complex_as_vector(const Variant& a);

// Basic math functions.
static Ref<NDArray> positive(const Variant& a);
Expand Down
141 changes: 107 additions & 34 deletions src/vatensor/rearrange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include <functional> // for multiplies
#include <numeric> // for accumulate, iota
#include <set> // for operator==, set

#include "create.hpp"
#include "util.hpp"
#include "vpromote.hpp"
#include "xscalar_store.hpp"
Expand Down Expand Up @@ -203,32 +205,31 @@ std::shared_ptr<VArray> va::join_axes_into_last_dimension(const VArray& varray,
);
}

template <typename T>
std::shared_ptr<VArray> reinterpret_complex_as_floats(const VArray& varray, const T& carray, std::ptrdiff_t offset) {
using V = typename std::decay_t<decltype(carray)>::value_type;

auto new_strides = carray.strides();
for (auto& stride : new_strides) { stride *= 2; }

return std::make_shared<VArray>(VArray {
std::shared_ptr(varray.store),
make_compute(
reinterpret_cast<typename V::value_type*>(const_cast<V*>(carray.data())) + offset,
carray.shape(),
new_strides,
xt::layout_type::dynamic
),
varray.data_offset * 2 + offset
});
}

std::shared_ptr<VArray> va::real(const std::shared_ptr<VArray>& varray) {
std::shared_ptr<VArray> reinterpret_complex_as_floats(const std::shared_ptr<VArray>& varray, std::ptrdiff_t offset, bool add_dimension) {
return std::visit(
[&varray](auto& carray) -> std::shared_ptr<VArray> {
using V = typename std::decay_t<decltype(carray)>::value_type;
[&varray, offset, add_dimension](auto& carray) -> std::shared_ptr<VArray> {
using V = typename std::decay_t<decltype(carray)>::value_type;

if constexpr (xtl::is_complex<V>::value) {
return reinterpret_complex_as_floats(*varray, carray, 0);
using V = typename std::decay_t<decltype(carray)>::value_type;

strides_type new_strides = carray.strides();
for (auto& stride : new_strides) { stride *= 2; }
if (add_dimension) new_strides.push_back(1);

shape_type new_shape = carray.shape();
if (add_dimension) new_shape.push_back(2);

return std::make_shared<VArray>(VArray {
std::shared_ptr(varray->store),
make_compute(
reinterpret_cast<typename V::value_type*>(const_cast<V*>(carray.data())) + offset,
new_shape,
new_strides,
(add_dimension && (carray.layout() == xt::layout_type::row_major || carray.layout() == xt::layout_type::any)) ? xt::layout_type::row_major : xt::layout_type::dynamic
),
varray->data_offset * 2 + offset
});
}
else {
return varray;
Expand All @@ -237,16 +238,88 @@ std::shared_ptr<VArray> va::real(const std::shared_ptr<VArray>& varray) {
);
}

std::shared_ptr<VArray> va::real(const std::shared_ptr<VArray>& varray) {
return reinterpret_complex_as_floats(varray, 0, false);
}

std::shared_ptr<VArray> va::imag(const std::shared_ptr<VArray>& varray) {
return std::visit(
[&varray](auto& carray) -> std::shared_ptr<VArray> {
using V = typename std::decay_t<decltype(carray)>::value_type;
if constexpr (xtl::is_complex<V>::value) {
return reinterpret_complex_as_floats(*varray, carray, 1);
}
else {
return va::store::full_dummy_like(0, carray);
}
}, varray->data
);
return reinterpret_complex_as_floats(varray, 1, false);
}

std::shared_ptr<VArray> va::complex_as_vector(const std::shared_ptr<VArray>& varray) {
return reinterpret_complex_as_floats(varray, 0, true);
}

std::shared_ptr<VArray> va::vector_as_complex(VStoreAllocator& allocator, const VArray& varray, DType dtype, bool keepdims) {
const auto dim_count = varray.dimension();

if (dim_count < 1) { throw std::invalid_argument("Array must have at least one dimension"); }

const auto& strides = varray.strides();
const auto& shape = varray.shape();

if (shape.back() != 2) { throw std::invalid_argument("Last dimension shape must be 2"); }

if (strides.back() == 1 && std::visit([dtype](auto& carray) -> bool {
using V = typename std::decay_t<decltype(carray)>::value_type;

if constexpr (xtl::is_complex<V>::value) {
throw std::runtime_error("Complex vector cannot be reinterpreted as real vector");
}
else if constexpr (std::is_floating_point_v<V>) {
return dtype == DTypeMax || dtype == dtype_of_type<std::complex<V>>();
}
else {
return false;
}
}, varray.data)) {
// Can return a view!

// Remove last dimension.
auto new_strides = strides_type(dim_count - (keepdims ? 0 : 1));
for (int i = 0; i < dim_count - 1; ++i) { new_strides[i] = strides[i] / 2; }
if (keepdims) { new_strides.back() = 0; }

auto new_shape = shape_type(dim_count - 1);
std::copy_n(shape.begin(), dim_count - 1, new_shape.begin());
if (keepdims) { new_shape.back() = 1; }

const auto new_layout = keepdims ? xt::layout_type::dynamic : varray.layout();

return std::make_shared<VArray>(VArray {
std::shared_ptr(varray.store),
std::visit([&new_shape, &new_strides, &new_layout](auto& carray) -> VData {
using V = typename std::decay_t<decltype(carray)>::value_type;

if constexpr (!std::is_floating_point_v<V>) {
throw std::runtime_error("internal error");
}
else {
return make_compute(
reinterpret_cast<std::complex<V>*>(const_cast<V*>(carray.data())),
new_shape,
new_strides,
new_layout
);
}
}, varray.data),
varray.data_offset * 2
});
}

// Need to return a copy.
if (dtype == DTypeMax) { dtype = DType::Complex128; }

const DType comp_dtype = std::visit([](auto t) -> DType {
if constexpr (xtl::is_complex<decltype(t)>::value) {
return dtype_of_type<typename decltype(t)::value_type>();
}
else {
throw std::runtime_error("DType must be complex");
}
}, dtype_to_variant(dtype));

const auto float_array = va::copy_as_dtype(allocator, varray.data, comp_dtype);
// Call ourselves again, though this time we should get a view for sure.
return vector_as_complex(allocator, *float_array, dtype, keepdims);
}
2 changes: 2 additions & 0 deletions src/vatensor/rearrange.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ namespace va {

std::shared_ptr<VArray> real(const std::shared_ptr<VArray>& varray);
std::shared_ptr<VArray> imag(const std::shared_ptr<VArray>& varray);
std::shared_ptr<VArray> complex_as_vector(const std::shared_ptr<VArray>& varray);
std::shared_ptr<VArray> vector_as_complex(VStoreAllocator& allocator, const VArray& varray, DType dtype, bool keepdims);
}

#endif //XV_H

0 comments on commit 94049f7

Please sign in to comment.