diff --git a/lib/orchpy/arrays/dist/core.py b/lib/orchpy/arrays/dist/core.py index 08d379b3..6feaa94d 100644 --- a/lib/orchpy/arrays/dist/core.py +++ b/lib/orchpy/arrays/dist/core.py @@ -4,7 +4,8 @@ import orchpy as op __all__ = ["BLOCK_SIZE", "DistArray", "assemble", "zeros", "ones", "copy", - "eye", "triu", "tril", "blockwise_dot", "dot", "block_column", "block_row"] + "eye", "triu", "tril", "blockwise_dot", "dot", "block_column", + "block_row", "gather", "shape"] BLOCK_SIZE = 10 @@ -202,3 +203,23 @@ def block_row(a, row): result = DistArray() result.construct(shape, a.objrefs[row, :]) return result + +@op.distributed([np.ndarray], [List[tuple]]) +def shape(objrefs): + shapes = [single.shape(objref) for objref in objrefs] + return [op.pull(shape) for shape in shapes] + +@op.distributed([np.ndarray, np.ndarray, int], [np.ndarray]) +def gather(objrefs, indices, axis): + perm = indices.argsort() + sorted_indices = indices[perm] + shapes = op.pull(shape(objrefs)) + cumsizes = np.zeros(len(shapes) + 1, dtype="int64") # + 1 for leading zero + np.array([s[axis] for s in shapes]).cumsum(out=cumsizes[1:]) + ranges = np.searchsorted(sorted_indices, cumsizes[1:]) + idx = 0 + results = [] + for i, nextidx in enumerate(ranges): + results.append(single.gather(objrefs[i], sorted_indices[idx:nextidx] - cumsizes[i])) + idx = nextidx + return np.vstack([op.pull(r) for r in results])[indices] diff --git a/lib/orchpy/arrays/single/__init__.py b/lib/orchpy/arrays/single/__init__.py index 967afbb6..0f9d7834 100644 --- a/lib/orchpy/arrays/single/__init__.py +++ b/lib/orchpy/arrays/single/__init__.py @@ -1,2 +1,2 @@ import random, linalg -from core import zeros, zeros_like, ones, eye, dot, vstack, hstack, subarray, copy, tril, triu +from core import zeros, zeros_like, ones, eye, dot, vstack, hstack, subarray, copy, tril, triu, gather, shape diff --git a/lib/orchpy/arrays/single/core.py b/lib/orchpy/arrays/single/core.py index e634dd8f..a6360118 100644 --- a/lib/orchpy/arrays/single/core.py +++ b/lib/orchpy/arrays/single/core.py @@ -49,3 +49,11 @@ def tril(a): @op.distributed([np.ndarray], [np.ndarray]) def triu(a): return np.triu(a) + +@op.distributed([np.ndarray, np.ndarray], [np.ndarray]) +def gather(a, indices): + return a[indices] + +@op.distributed([np.ndarray], [tuple]) +def shape(a): + return a.shape diff --git a/src/orchpylib.cc b/src/orchpylib.cc index 293e1fb7..7fea77d9 100644 --- a/src/orchpylib.cc +++ b/src/orchpylib.cc @@ -237,6 +237,13 @@ int serialize(PyObject* val, Obj* obj) { } } break; + case NPY_INT64: { + npy_int64* buffer = (npy_int64*) PyArray_DATA(array); + for (npy_intp i = 0; i < size; ++i) { + data->add_int_data(buffer[i]); + } + } + break; case NPY_OBJECT: { // FIXME(pcm): Support arbitrary python objects, not only objrefs PyArrayIterObject* iter = (PyArrayIterObject*) PyArray_IterNew((PyObject*)array); while (PyArray_ITER_NOTDONE(iter)) { @@ -327,6 +334,13 @@ PyObject* deserialize(const Obj& obj) { } } break; + case NPY_INT64: { + npy_int64* buffer = (npy_int64*) PyArray_DATA(pyarray); + for (npy_intp i = 0; i < size; ++i) { + buffer[i] = array.int_data(i); + } + } + break; default: PyErr_SetString(OrchPyError, "deserialization: internal error (array type not implemented)"); return NULL; diff --git a/test/arrays_test.py b/test/arrays_test.py index f2b6e972..eac116a2 100644 --- a/test/arrays_test.py +++ b/test/arrays_test.py @@ -191,6 +191,17 @@ def testMethods(self): z = dist.dot(x, y) self.assertTrue(np.allclose(orchpy.pull(dist.assemble(z)), np.dot(orchpy.pull(dist.assemble(x)), orchpy.pull(dist.assemble(y))))) + x = single.random.normal([25, 49]) + y = single.random.normal([13, 49]) + z = single.random.normal([100, 49]) + w = single.random.normal([2, 49]) + shapes = orchpy.pull(dist.shape(np.array([x, y, z, w]))) + self.assertTrue(np.array([shape[0] for shape in shapes]).sum() == 25 + 13 + 100 + 2) + perm = np.random.permutation(25 + 13 + 100 + 2) + result = np.vstack([orchpy.pull(x), orchpy.pull(y), orchpy.pull(z), orchpy.pull(w)]) + u = dist.gather(np.array([x, y, z, w], dtype=object), perm, 0) + self.assertTrue(np.alltrue(result[perm,:] == orchpy.pull(u))) + services.cleanup() if __name__ == '__main__':