From 81ce3865b122ac16c7d6ac74f78957d60fdb5515 Mon Sep 17 00:00:00 2001 From: fverdugo Date: Fri, 7 Jun 2019 17:42:23 +0200 Subject: [PATCH] apply on IndexCellArrays --- src/ArrayOperations.jl | 38 ++++++++++++++++++++++++++++++++++++ test/ArrayOperationsTests.jl | 29 +++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/src/ArrayOperations.jl b/src/ArrayOperations.jl index 301b226..38771ff 100644 --- a/src/ArrayOperations.jl +++ b/src/ArrayOperations.jl @@ -2,6 +2,7 @@ module ArrayOperations using CellwiseValues using CellwiseValues.NumberOperations: _checks +using CellwiseValues.NumberOperations: _getvalues using CellwiseValues.Kernels: _size_for_broadcast import CellwiseValues: apply @@ -16,6 +17,10 @@ function apply(k::ArrayKernel,v::Vararg{<:CellValue}) CellArrayFromKernel(k,v...) end +function apply(k::ArrayKernel,v::Vararg{<:IndexCellValue}) + IndexCellArrayFromKernel(k,v...) +end + function _apply(f,v,::Val{true}) k = ArrayKernelFromBroadcastedFunction(f) apply(k,v...) @@ -97,5 +102,38 @@ _compute_sizes(a1,a2,a3) = (_bs(a1),_bs(a2),_bs(a3)) _compute_sizes(a1,a2,a3,a4) = (_bs(a1),_bs(a2),_bs(a3),_bs(a4)) _compute_sizes(a1,a2,a3,a4,a5) = (_bs(a1),_bs(a2),_bs(a3),_bs(a4),_bs(a5)) _compute_sizes(a1,a2,a3,a4,a5,a6) = (_bs(a1),_bs(a2),_bs(a3),_bs(a4),_bs(a5),_bs(a6)) + +struct IndexCellArrayFromKernel{T,N,K,V} <: IndexCellArray{T,N,CachedArray{T,N,Array{T,N}},1} + kernel::K + cellvalues::V + cache::CachedArray{T,N,Array{T,N}} +end + +function IndexCellArrayFromKernel(k::ArrayKernel,v::Vararg{<:CellValue}) + _checks(v) + T = _compute_type(k,v) + N = _compute_N(v) + K = typeof(k) + V = typeof(v) + cache = CachedArray(T,N) + IndexCellArrayFromKernel{T,N,K,V}(k,v,cache) +end + +function length(self::IndexCellArrayFromKernel) + vi, = self.cellvalues + length(vi) +end + +size(self::IndexCellArrayFromKernel) = (length(self),) + +function getindex(self::IndexCellArrayFromKernel,i::Integer) + a = _getvalues(i,self.cellvalues...) + s = _compute_sizes(a...) + z = compute_size(self.kernel,s...) + v = self.cache + setsize!(v,z) + compute_value!(v,self.kernel,a...) + v +end end # module ArrayOperations diff --git a/test/ArrayOperationsTests.jl b/test/ArrayOperationsTests.jl index 2fc643b..3c3ecf7 100644 --- a/test/ArrayOperationsTests.jl +++ b/test/ArrayOperationsTests.jl @@ -36,4 +36,33 @@ for (a,b) in zip(ax,bx) test_iter_cell_array(w,o) end +a = [1,2,3] +v = TestIndexCellValue(a,l) +w = apply(-,v,broadcast=true) +o = [ CachedArray(-vi) for vi in v ] +test_index_cell_array(w,o) + +a = [1,2,3] +b = [3,2,1] +u = TestIndexCellValue(a,l) +v = TestIndexCellValue(b,l) +w = apply(-,u,v,broadcast=true) +o = [ CachedArray(ui.-vi) for (ui,vi) in zip(u,v) ] +test_index_cell_array(w,o) + +v1 = VectorValue(2,3) +v2 = VectorValue(3,2) +v3 = VectorValue(1,2) + +ax = [rand(1,3,4), 1.0 , [v1,v2,v3], [v1,v2,v3], 1 ] +bx = [rand(2,3,1), rand(2,3), [v2,v3,v1], v1 , [v2,v3,v1]] + +for (a,b) in zip(ax,bx) + u = TestIndexCellValue(a,l) + v = TestIndexCellValue(b,l) + o = [ CachedArray(ui.-vi) for (ui,vi) in zip(u,v) ] + w = apply(-,u,v,broadcast=true) + test_index_cell_array(w,o) +end + end # module ArrayOperationsTests