From b838693284b818320f7d83919b8f6fed7dc8abfa Mon Sep 17 00:00:00 2001 From: Peter Abeles Date: Tue, 3 Nov 2020 17:12:16 -0800 Subject: [PATCH] CommonOps_DSCC - Speed up more mult variants --- change.txt | 3 +- .../csc/BenchmarkMatrixMultDense_DSCC.java | 16 ++- .../csc/BenchmarkMatrixMultDense_MT_DSCC.java | 13 +- .../org/ejml/sparse/csc/CommonOps_DSCC.java | 32 +++-- .../ejml/sparse/csc/CommonOps_MT_DSCC.java | 78 +++++++++-- .../csc/misc/ImplCommonOps_MT_DSCC.java | 6 +- .../csc/mult/ImplMultiplication_DSCC.java | 90 +++++++----- .../csc/mult/ImplMultiplication_MT_DSCC.java | 130 ++++++++++++------ .../ejml/sparse/csc/TestCommonOps_DSCC.java | 12 +- .../sparse/csc/TestCommonOps_MT_DSCC.java | 53 ++++--- .../csc/misc/TestImplCommonOps_MT_DSCC.java | 6 +- .../csc/mult/TestImplMultiplication_DSCC.java | 9 +- .../mult/TestImplMultiplication_MT_DSCC.java | 56 ++++++-- .../simple/ops/SimpleOperations_DSCC.java | 4 +- 14 files changed, 343 insertions(+), 165 deletions(-) diff --git a/change.txt b/change.txt index e8284ba47..c1d255dca 100644 --- a/change.txt +++ b/change.txt @@ -29,8 +29,9 @@ Date format: year/month/day * Removed the functions below since they had a runtime complexity of O(N^2) relative to matrix size instead of O(N) - multTransA(S,S,S), multTransB(S,S,S), innerProductLower(S,S,S) - Thanks Florentin Dörre for first noticing the performance issue + * Speed up multTransAB(S,D,D), multTransA(S,D,D), multTransB(S,D,D) by a large margin - DMatrixSparseCSC - * If sorted a binary search is used to lookup rows. Thanks Florentin Dörre. + * If sorted, a binary search is used to lookup rows. Thanks Florentin Dörre. 
- ReadMatrixCsv * Thanks DEDZTBH for fixing an indexing error when reading complex data types - Added Concurrent Algorithms diff --git a/main/ejml-dsparse/benchmarks/src/org/ejml/sparse/csc/BenchmarkMatrixMultDense_DSCC.java b/main/ejml-dsparse/benchmarks/src/org/ejml/sparse/csc/BenchmarkMatrixMultDense_DSCC.java index 061452576..5a1d93295 100644 --- a/main/ejml-dsparse/benchmarks/src/org/ejml/sparse/csc/BenchmarkMatrixMultDense_DSCC.java +++ b/main/ejml-dsparse/benchmarks/src/org/ejml/sparse/csc/BenchmarkMatrixMultDense_DSCC.java @@ -18,6 +18,7 @@ package org.ejml.sparse.csc; +import org.ejml.data.DGrowArray; import org.ejml.data.DMatrixRMaj; import org.ejml.data.DMatrixSparseCSC; import org.ejml.dense.row.RandomMatrices_DDRM; @@ -47,27 +48,30 @@ public class BenchmarkMatrixMultDense_DSCC { @Param({"100000"}) private int elementCount; - DMatrixSparseCSC A; + DMatrixSparseCSC A,A_small; DMatrixRMaj B = new DMatrixRMaj(1, 1); DMatrixRMaj C = new DMatrixRMaj(1, 1); + DGrowArray work = new DGrowArray(); + @Setup public void setup() { Random rand = new Random(2345); A = RandomMatrices_DSCC.rectangle(dimension, dimension, elementCount, rand); + A_small = RandomMatrices_DSCC.rectangle(dimension/4, dimension/4, elementCount/4, rand); B = RandomMatrices_DDRM.rectangle(dimension, dimension, -1, 1, rand); C = B.create(dimension, dimension); } @Benchmark public void mult() { CommonOps_DSCC.mult(A, B, C); } @Benchmark public void multAdd() { CommonOps_DSCC.multAdd(A, B, C); } - @Benchmark public void multTransA() { CommonOps_DSCC.multTransA(A, B, C); } - @Benchmark public void multAddTransA() { CommonOps_DSCC.multAddTransA(A, B, C); } - @Benchmark public void multTransB() { CommonOps_DSCC.multTransB(A, B, C); } - @Benchmark public void multAddTransB() { CommonOps_DSCC.multAddTransB(A, B, C); } + @Benchmark public void multTransA() { CommonOps_DSCC.multTransA(A, B, C, work); } + @Benchmark public void multAddTransA() { CommonOps_DSCC.multAddTransA(A, B, C, work); } + @Benchmark 
public void multTransB() { CommonOps_DSCC.multTransB(A, B, C, work); } + @Benchmark public void multAddTransB() { CommonOps_DSCC.multAddTransB(A, B, C, work); } @Benchmark public void multTransAB() { CommonOps_DSCC.multTransAB(A, B, C); } @Benchmark public void multAddTransAB() { CommonOps_DSCC.multAddTransAB(A, B, C); } - @Benchmark public void invert() { CommonOps_DSCC.invert(A, C); } + @Benchmark public void invert() { CommonOps_DSCC.invert(A_small, C); } public static void main( String[] args ) throws RunnerException { Options opt = new OptionsBuilder() diff --git a/main/ejml-dsparse/benchmarks/src/org/ejml/sparse/csc/BenchmarkMatrixMultDense_MT_DSCC.java b/main/ejml-dsparse/benchmarks/src/org/ejml/sparse/csc/BenchmarkMatrixMultDense_MT_DSCC.java index 665eda3d9..d2dc56eee 100644 --- a/main/ejml-dsparse/benchmarks/src/org/ejml/sparse/csc/BenchmarkMatrixMultDense_MT_DSCC.java +++ b/main/ejml-dsparse/benchmarks/src/org/ejml/sparse/csc/BenchmarkMatrixMultDense_MT_DSCC.java @@ -49,7 +49,7 @@ public class BenchmarkMatrixMultDense_MT_DSCC { @Param({"100000"}) private int elementCount; - DMatrixSparseCSC A; + DMatrixSparseCSC A,A_small; DMatrixRMaj B = new DMatrixRMaj(1, 1); DMatrixRMaj C = new DMatrixRMaj(1, 1); @@ -59,19 +59,20 @@ public class BenchmarkMatrixMultDense_MT_DSCC { public void setup() { Random rand = new Random(2345); A = RandomMatrices_DSCC.rectangle(dimension, dimension, elementCount, rand); + A_small = RandomMatrices_DSCC.rectangle(dimension/4, dimension/4, elementCount/4, rand); B = RandomMatrices_DDRM.rectangle(dimension, dimension, -1, 1, rand); C = B.create(dimension, dimension); } @Benchmark public void mult() { CommonOps_MT_DSCC.mult(A, B, C, work); } @Benchmark public void multAdd() { CommonOps_MT_DSCC.multAdd(A, B, C, work); } - @Benchmark public void multTransA() { CommonOps_MT_DSCC.multTransA(A, B, C); } - @Benchmark public void multAddTransA() { CommonOps_MT_DSCC.multAddTransA(A, B, C); } + @Benchmark public void multTransA() { 
CommonOps_MT_DSCC.multTransA(A, B, C, work); } + @Benchmark public void multAddTransA() { CommonOps_MT_DSCC.multAddTransA(A, B, C, work); } @Benchmark public void multTransB() { CommonOps_MT_DSCC.multTransB(A, B, C, work); } @Benchmark public void multAddTransB() { CommonOps_MT_DSCC.multAddTransB(A, B, C, work); } -// @Benchmark public void multTransAB() { CommonOps_MT_DSCC.multTransAB(A, B, C); } -// @Benchmark public void multAddTransAB() { CommonOps_MT_DSCC.multAddTransAB(A, B, C); } -// @Benchmark public void invert() { CommonOps_MT_DSCC.invert(A, C); } + @Benchmark public void multTransAB() { CommonOps_MT_DSCC.multTransAB(A, B, C); } + @Benchmark public void multAddTransAB() { CommonOps_MT_DSCC.multAddTransAB(A, B, C); } +// @Benchmark public void invert() { CommonOps_MT_DSCC.invert(A_small, C); } public static void main( String[] args ) throws RunnerException { Options opt = new OptionsBuilder() diff --git a/main/ejml-dsparse/src/org/ejml/sparse/csc/CommonOps_DSCC.java b/main/ejml-dsparse/src/org/ejml/sparse/csc/CommonOps_DSCC.java index 91fcd9cd5..55979cd7f 100644 --- a/main/ejml-dsparse/src/org/ejml/sparse/csc/CommonOps_DSCC.java +++ b/main/ejml-dsparse/src/org/ejml/sparse/csc/CommonOps_DSCC.java @@ -188,13 +188,17 @@ public static void multAdd( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outpu * @param B Dense Matrix * @param outputC Dense Matrix */ - public static DMatrixRMaj multTransA( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC ) { + public static DMatrixRMaj multTransA( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC, + @Nullable DGrowArray work ) { if (A.numRows != B.numRows) throw new MatrixDimensionException("Inconsistent matrix shapes. 
" + stringShapes(A, B)); + if (work == null) + work = new DGrowArray(); + outputC = reshapeOrDeclare(outputC, A.numCols, B.numCols); - ImplMultiplication_DSCC.multTransA(A, B, outputC); + ImplMultiplication_DSCC.multTransA(A, B, outputC, work); return outputC; } @@ -202,13 +206,17 @@ public static DMatrixRMaj multTransA( DMatrixSparseCSC A, DMatrixRMaj B, @Nullab /** *

<p>C = C + A<sup>T</sup>*B</p>

*/ - public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC ) { + public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC, + @Nullable DGrowArray work ) { if (A.numRows != B.numRows) throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B)); if (A.numCols != outputC.numRows || B.numCols != outputC.numCols) throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC)); - ImplMultiplication_DSCC.multAddTransA(A, B, outputC); + if (work == null) + work = new DGrowArray(); + + ImplMultiplication_DSCC.multAddTransA(A, B, outputC, work); } /** @@ -218,12 +226,16 @@ public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj * @param B Dense Matrix * @param outputC Dense Matrix */ - public static DMatrixRMaj multTransB( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC ) { + public static DMatrixRMaj multTransB( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC, + @Nullable DGrowArray work ) { if (A.numCols != B.numCols) throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B)); outputC = reshapeOrDeclare(outputC, A.numRows, B.numRows); - ImplMultiplication_DSCC.multTransB(A, B, outputC); + if (work == null) + work = new DGrowArray(); + + ImplMultiplication_DSCC.multTransB(A, B, outputC, work); return outputC; } @@ -231,13 +243,17 @@ public static DMatrixRMaj multTransB( DMatrixSparseCSC A, DMatrixRMaj B, @Nullab /** *

<p>C = C + A*B<sup>T</sup></p>

*/ - public static void multAddTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC ) { + public static void multAddTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC, + @Nullable DGrowArray work ) { if (A.numCols != B.numCols) throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B)); if (A.numRows != outputC.numRows || B.numRows != outputC.numCols) throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC)); - ImplMultiplication_DSCC.multAddTransB(A, B, outputC); + if (work == null) + work = new DGrowArray(); + + ImplMultiplication_DSCC.multAddTransB(A, B, outputC, work); } /** diff --git a/main/ejml-dsparse/src/org/ejml/sparse/csc/CommonOps_MT_DSCC.java b/main/ejml-dsparse/src/org/ejml/sparse/csc/CommonOps_MT_DSCC.java index e41befed9..68f744bee 100644 --- a/main/ejml-dsparse/src/org/ejml/sparse/csc/CommonOps_MT_DSCC.java +++ b/main/ejml-dsparse/src/org/ejml/sparse/csc/CommonOps_MT_DSCC.java @@ -56,6 +56,9 @@ public static DMatrixSparseCSC mult( DMatrixSparseCSC A, DMatrixSparseCSC B, @Nu throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B)); outputC = reshapeOrDeclare(outputC, A, A.numRows, B.numCols); + if (listWork == null) + listWork = new GrowArray<>(Workspace_MT_DSCC::new); + ImplMultiplication_MT_DSCC.mult(A, B, outputC, listWork); return outputC; @@ -80,6 +83,9 @@ public static DMatrixSparseCSC add( double alpha, DMatrixSparseCSC A, double bet throw new MatrixDimensionException("Inconsistent matrix shapes. 
" + stringShapes(A, B)); outputC = reshapeOrDeclare(outputC, A, A.numRows, A.numCols); + if (listWork == null) + listWork = new GrowArray<>(Workspace_MT_DSCC::new); + ImplCommonOps_MT_DSCC.add(alpha, A, beta, B, outputC, listWork); return outputC; @@ -93,12 +99,14 @@ public static DMatrixSparseCSC add( double alpha, DMatrixSparseCSC A, double bet * @param outputC Dense Matrix */ public static DMatrixRMaj mult( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC, - @Nullable GrowArray listWork ) { + @Nullable GrowArray workArrays ) { if (A.numCols != B.numRows) throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B)); outputC = reshapeOrDeclare(outputC, A.numRows, B.numCols); + if (workArrays == null) + workArrays = new GrowArray<>(DGrowArray::new); - ImplMultiplication_MT_DSCC.mult(A, B, outputC, listWork); + ImplMultiplication_MT_DSCC.mult(A, B, outputC, workArrays); return outputC; } @@ -107,13 +115,16 @@ public static DMatrixRMaj mult( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMa *

<p>C = C + A*B</p>

*/ public static void multAdd( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC, - @Nullable GrowArray listWork ) { + @Nullable GrowArray workArrays ) { if (A.numCols != B.numRows) throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B)); if (A.numRows != outputC.numRows || B.numCols != outputC.numCols) throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC)); - ImplMultiplication_MT_DSCC.multAdd(A, B, outputC, listWork); + if (workArrays == null) + workArrays = new GrowArray<>(DGrowArray::new); + + ImplMultiplication_MT_DSCC.multAdd(A, B, outputC, workArrays); } /** @@ -123,13 +134,17 @@ public static void multAdd( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outpu * @param B Dense Matrix * @param outputC Dense Matrix */ - public static DMatrixRMaj multTransA( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC ) { + public static DMatrixRMaj multTransA( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC, + @Nullable GrowArray workArray ) { if (A.numRows != B.numRows) throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B)); outputC = reshapeOrDeclare(outputC, A.numCols, B.numCols); - ImplMultiplication_MT_DSCC.multTransA(A, B, outputC); + if (workArray == null) + workArray = new GrowArray<>(DGrowArray::new); + + ImplMultiplication_MT_DSCC.multTransA(A, B, outputC, workArray); return outputC; } @@ -137,13 +152,17 @@ public static DMatrixRMaj multTransA( DMatrixSparseCSC A, DMatrixRMaj B, @Nullab /** *

<p>C = C + A<sup>T</sup>*B</p>

*/ - public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC ) { + public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC, + @Nullable GrowArray workArray ) { if (A.numRows != B.numRows) throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B)); if (A.numCols != outputC.numRows || B.numCols != outputC.numCols) throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC)); - ImplMultiplication_MT_DSCC.multAddTransA(A, B, outputC); + if (workArray == null) + workArray = new GrowArray<>(DGrowArray::new); + + ImplMultiplication_MT_DSCC.multAddTransA(A, B, outputC, workArray); } /** @@ -154,12 +173,15 @@ public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj * @param outputC Dense Matrix */ public static DMatrixRMaj multTransB( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC, - @Nullable GrowArray listWork ) { + @Nullable GrowArray workArrays ) { if (A.numCols != B.numCols) throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B)); outputC = reshapeOrDeclare(outputC, A.numRows, B.numRows); - ImplMultiplication_MT_DSCC.multTransB(A, B, outputC, listWork); + if (workArrays == null) + workArrays = new GrowArray<>(DGrowArray::new); + + ImplMultiplication_MT_DSCC.multTransB(A, B, outputC, workArrays); return outputC; } @@ -168,12 +190,44 @@ public static DMatrixRMaj multTransB( DMatrixSparseCSC A, DMatrixRMaj B, @Nullab *

<p>C = C + A*B<sup>T</sup></p>

*/ public static void multAddTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC, - @Nullable GrowArray listWork ) { + @Nullable GrowArray workArrays ) { if (A.numCols != B.numCols) throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B)); if (A.numRows != outputC.numRows || B.numRows != outputC.numCols) throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC)); - ImplMultiplication_MT_DSCC.multAddTransB(A, B, outputC, listWork); + if (workArrays == null) + workArrays = new GrowArray<>(DGrowArray::new); + + ImplMultiplication_MT_DSCC.multAddTransB(A, B, outputC, workArrays); + } + + /** + * Performs matrix multiplication. C = AT*BT + * + * @param A Matrix + * @param B Dense Matrix + * @param outputC Dense Matrix + */ + public static DMatrixRMaj multTransAB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC ) { + if (A.numRows != B.numCols) + throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B)); + outputC = reshapeOrDeclare(outputC, A.numCols, B.numRows); + + ImplMultiplication_MT_DSCC.multTransAB(A, B, outputC); + + return outputC; + } + + /** + *

<p>C = C + A<sup>T</sup>*B<sup>T</sup></p>

+ */ + public static void multAddTransAB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC ) { + if (A.numRows != B.numCols) + throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B)); + if (A.numCols != outputC.numRows || B.numRows != outputC.numCols) + throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC)); + + ImplMultiplication_MT_DSCC.multAddTransAB(A, B, outputC); } } diff --git a/main/ejml-dsparse/src/org/ejml/sparse/csc/misc/ImplCommonOps_MT_DSCC.java b/main/ejml-dsparse/src/org/ejml/sparse/csc/misc/ImplCommonOps_MT_DSCC.java index ecabfc271..5c29784f1 100644 --- a/main/ejml-dsparse/src/org/ejml/sparse/csc/misc/ImplCommonOps_MT_DSCC.java +++ b/main/ejml-dsparse/src/org/ejml/sparse/csc/misc/ImplCommonOps_MT_DSCC.java @@ -21,7 +21,6 @@ import org.ejml.concurrency.EjmlConcurrency; import org.ejml.data.DMatrixSparseCSC; import org.ejml.sparse.csc.mult.Workspace_MT_DSCC; -import org.jetbrains.annotations.Nullable; import pabeles.concurrency.GrowArray; import static org.ejml.UtilEjml.adjust; @@ -46,10 +45,7 @@ public class ImplCommonOps_MT_DSCC { * @param listWork (Optional) Storage for internal workspace. Can be null. 
*/ public static void add( double alpha, DMatrixSparseCSC A, double beta, DMatrixSparseCSC B, DMatrixSparseCSC C, - @Nullable GrowArray listWork ) { - if (listWork == null) - listWork = new GrowArray<>(Workspace_MT_DSCC::new); - + GrowArray listWork ) { // Break the problem up into blocks of columns and process them independently EjmlConcurrency.loopBlocks(0, A.numCols, listWork, ( workspace, col0, col1 ) -> { DMatrixSparseCSC workC = workspace.mat; diff --git a/main/ejml-dsparse/src/org/ejml/sparse/csc/mult/ImplMultiplication_DSCC.java b/main/ejml-dsparse/src/org/ejml/sparse/csc/mult/ImplMultiplication_DSCC.java index 9971a4a54..7a8c03830 100644 --- a/main/ejml-dsparse/src/org/ejml/sparse/csc/mult/ImplMultiplication_DSCC.java +++ b/main/ejml-dsparse/src/org/ejml/sparse/csc/mult/ImplMultiplication_DSCC.java @@ -22,6 +22,7 @@ import org.ejml.data.DMatrixRMaj; import org.ejml.data.DMatrixSparseCSC; import org.ejml.data.IGrowArray; +import org.ejml.ops.DOperatorBinary; import org.jetbrains.annotations.Nullable; import java.util.Arrays; @@ -173,29 +174,24 @@ public static void multAdd( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C ) { } } - public static void multTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C ) { - - // C(i,j) = sum_k A(k,i) * B(k,j) - for (int j = 0; j < B.numCols; j++) { - - for (int i = 0; i < A.numCols; i++) { - int idx0 = A.col_idx[i]; - int idx1 = A.col_idx[i + 1]; - - double sum = 0; - for (int indexA = idx0; indexA < idx1; indexA++) { - int rowK = A.nz_rows[indexA]; - sum += A.nz_values[indexA]*B.data[rowK*B.numCols + j]; - } + public static void multTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, DGrowArray workArray ) { + multTransA(A, B, C, workArray, ( a, b ) -> b); + } - C.data[i*C.numCols + j] = sum; - } - } + public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, DGrowArray workArray ) { + multTransA(A, B, C, workArray, Double::sum); } - public static void multAddTransA( 
DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C ) { + public static void multTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, DGrowArray workArray, + DOperatorBinary op ) { + double[] work = workArray.reshape(B.numRows).data; + // C(i,j) = sum_k A(k,i) * B(k,j) for (int j = 0; j < B.numCols; j++) { + // local copy of row to avoid cache misses + for (int k = 0; k < B.numRows; k++) { + work[k] = B.data[k*B.numCols + j]; + } for (int i = 0; i < A.numCols; i++) { int idx0 = A.col_idx[i]; @@ -203,55 +199,77 @@ public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj double sum = 0; for (int indexA = idx0; indexA < idx1; indexA++) { - int rowK = A.nz_rows[indexA]; - sum += A.nz_values[indexA]*B.data[rowK*B.numCols + j]; + int k = A.nz_rows[indexA]; + sum += A.nz_values[indexA]*work[k]; } - C.data[i*C.numCols + j] += sum; + C.data[i*C.numCols + j] = op.apply(C.data[i*C.numCols + j], sum); } } } - public static void multTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C ) { - + public static void multTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, DGrowArray workArray ) { C.zero(); - multAddTransB(A, B, C); + multAddTransB(A, B, C, workArray); } - public static void multAddTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C ) { + public static void multAddTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, DGrowArray workArray ) { + double[] work = workArray.reshape(B.numRows).data; // C(i,j) = sum_k A(i,k) * B(j,k) for (int k = 0; k < A.numCols; k++) { + // local copy of row to avoid cache misses + for (int j = 0; j < B.numRows; j++) { + work[j] = B.data[j*B.numCols + k]; + } + int idx0 = A.col_idx[k]; int idx1 = A.col_idx[k + 1]; for (int indexA = idx0; indexA < idx1; indexA++) { for (int j = 0; j < B.numRows; j++) { int i = A.nz_rows[indexA]; - C.data[i*C.numCols + j] += A.nz_values[indexA]*B.data[j*B.numCols + k]; + C.data[i*C.numCols + j] += A.nz_values[indexA]*work[j]; } } } } public static void 
multTransAB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C ) { - C.zero(); - multAddTransAB(A, B, C); + // C(i,j) = sum_k A(k,i) * B(j,K) + for (int j = 0; j < B.numRows; j++) { + for (int i = 0; i < A.numCols; i++) { + int idx0 = A.col_idx[i]; + int idx1 = A.col_idx[i + 1]; + + final int indexRowB = j*B.numCols; + + double sum = 0; + for (int indexA = idx0; indexA < idx1; indexA++) { + int k = A.nz_rows[indexA]; + sum += A.nz_values[indexA]*B.data[indexRowB + k]; + } + + C.data[i*C.numCols + j] = sum; + } + } } public static void multAddTransAB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C ) { // C(i,j) = sum_k A(k,i) * B(j,K) - for (int i = 0; i < A.numCols; i++) { - int idx0 = A.col_idx[i]; - int idx1 = A.col_idx[i + 1]; + for (int j = 0; j < B.numRows; j++) { + for (int i = 0; i < A.numCols; i++) { + int idx0 = A.col_idx[i]; + int idx1 = A.col_idx[i + 1]; - for (int indexA = idx0; indexA < idx1; indexA++) { - for (int j = 0; j < B.numRows; j++) { - int indexB = j*B.numCols; + final int indexRowB = j*B.numCols; + double sum = 0; + for (int indexA = idx0; indexA < idx1; indexA++) { int k = A.nz_rows[indexA]; - - C.data[i*C.numCols + j] += A.nz_values[indexA]*B.data[indexB + k]; + sum += A.nz_values[indexA]*B.data[indexRowB + k]; } + + C.data[i*C.numCols + j] += sum; } } } diff --git a/main/ejml-dsparse/src/org/ejml/sparse/csc/mult/ImplMultiplication_MT_DSCC.java b/main/ejml-dsparse/src/org/ejml/sparse/csc/mult/ImplMultiplication_MT_DSCC.java index f319530ca..3e0f07110 100644 --- a/main/ejml-dsparse/src/org/ejml/sparse/csc/mult/ImplMultiplication_MT_DSCC.java +++ b/main/ejml-dsparse/src/org/ejml/sparse/csc/mult/ImplMultiplication_MT_DSCC.java @@ -23,7 +23,6 @@ import org.ejml.data.DGrowArray; import org.ejml.data.DMatrixRMaj; import org.ejml.data.DMatrixSparseCSC; -import org.jetbrains.annotations.Nullable; import pabeles.concurrency.GrowArray; import java.util.Arrays; @@ -49,10 +48,7 @@ public class ImplMultiplication_MT_DSCC { * @param listWork (Optional) 
Storage for internal workspace. Can be null. */ public static void mult( DMatrixSparseCSC A, DMatrixSparseCSC B, DMatrixSparseCSC C, - @Nullable GrowArray listWork ) { - if (listWork == null) - listWork = new GrowArray<>(Workspace_MT_DSCC::new); - + GrowArray listWork ) { // Break the problem up into blocks of columns and process them independently EjmlConcurrency.loopBlocks(0, B.numCols, listWork, ( workspace, bj0, bj1 ) -> { DMatrixSparseCSC workC = workspace.mat; @@ -131,20 +127,17 @@ public static void stitchMatrix( DMatrixSparseCSC out, int numRows, int numCols, } public static void mult( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, - @Nullable GrowArray listWork ) { + GrowArray listWork ) { mult(A, B, C, false, listWork); } public static void multAdd( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, - @Nullable GrowArray listWork ) { + GrowArray listWork ) { mult(A, B, C, true, listWork); } public static void mult( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, boolean add, - @Nullable GrowArray listWork ) { - if (listWork == null) - listWork = new GrowArray<>(DGrowArray::new); - + GrowArray listWork ) { // Break the problem up into blocks of columns and process them independently EjmlConcurrency.loopBlocks(0, B.numCols, listWork, ( gwork, bj0, bj1 ) -> { // same array to store column in A and B. 
This is done to reduce cache misses in B and C access @@ -190,57 +183,74 @@ public static void mult( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, boole }); } - public static void multTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C ) { + public static void multTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, + GrowArray listWork ) { // C(i,j) = sum_k A(k,i) * B(k,j) - EjmlConcurrency.loopFor(0, B.numCols, j -> { - for (int i = 0; i < A.numCols; i++) { - int idx0 = A.col_idx[i]; - int idx1 = A.col_idx[i + 1]; + EjmlConcurrency.loopBlocks(0, B.numCols, listWork, ( gwork, j0, j1 ) -> { + // Local copy of column in A to reduce cache misses + double[] work = gwork.reshape(B.numRows).data; - double sum = 0; - for (int indexA = idx0; indexA < idx1; indexA++) { - int rowK = A.nz_rows[indexA]; - sum += A.nz_values[indexA]*B.data[rowK*B.numCols + j]; + for (int j = j0; j < j1; j++) { + for (int k = 0; k < B.numRows; k++) { + work[k] = B.data[k*B.numCols + j]; } - C.data[i*C.numCols + j] = sum; + for (int i = 0; i < A.numCols; i++) { + int idx0 = A.col_idx[i]; + int idx1 = A.col_idx[i + 1]; + + double sum = 0; + for (int indexA = idx0; indexA < idx1; indexA++) { + int k = A.nz_rows[indexA]; + sum += A.nz_values[indexA]*work[k]; +// sum += A.nz_values[indexA]*B.data[k*B.numCols + j]; + } + + C.data[i*C.numCols + j] = sum; + } } }); } - public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C ) { + public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, + GrowArray listWork ) { // C(i,j) = sum_k A(k,i) * B(k,j) - EjmlConcurrency.loopFor(0, B.numCols, j -> { - for (int i = 0; i < A.numCols; i++) { - int idx0 = A.col_idx[i]; - int idx1 = A.col_idx[i + 1]; + EjmlConcurrency.loopBlocks(0, B.numCols, listWork, ( gwork, j0, j1 ) -> { + // Local copy of column in A to reduce cache misses + double[] work = gwork.reshape(B.numRows).data; - double sum = 0; - for (int indexA = idx0; indexA < idx1; indexA++) { - 
int rowK = A.nz_rows[indexA]; - sum += A.nz_values[indexA]*B.data[rowK*B.numCols + j]; + for (int j = j0; j < j1; j++) { + for (int k = 0; k < B.numRows; k++) { + work[k] = B.data[k*B.numCols + j]; } - C.data[i*C.numCols + j] += sum; + for (int i = 0; i < A.numCols; i++) { + int idx0 = A.col_idx[i]; + int idx1 = A.col_idx[i + 1]; + + double sum = 0; + for (int indexA = idx0; indexA < idx1; indexA++) { + int k = A.nz_rows[indexA]; + sum += A.nz_values[indexA]*work[k]; +// sum += A.nz_values[indexA]*B.data[k*B.numCols + j]; + } + + C.data[i*C.numCols + j] += sum; + } } }); } - public static void multTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, - @Nullable GrowArray listWork ) { - mult(A, B, C, false, listWork); + public static void multTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, GrowArray listWork ) { + multTransB(A, B, C, false, listWork); } - public static void multAddTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, - @Nullable GrowArray listWork ) { + public static void multAddTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, GrowArray listWork ) { multTransB(A, B, C, true, listWork); } public static void multTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, boolean add, - @Nullable GrowArray listWork ) { - if (listWork == null) - listWork = new GrowArray<>(DGrowArray::new); - + GrowArray listWork ) { // Break the problem up into blocks of columns and process them independently EjmlConcurrency.loopBlocks(0, B.numRows, listWork, ( gwork, bj0, bj1 ) -> { // Local copy of column in A to reduce cache misses @@ -280,4 +290,44 @@ public static void multTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, } }); } + + public static void multTransAB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C ) { + // C(i,j) = sum_k A(k,i) * B(j,k) + EjmlConcurrency.loopFor(0, B.numRows, j -> { + for (int i = 0; i < A.numCols; i++) { + int idx0 = A.col_idx[i]; + int idx1 = A.col_idx[i + 1]; + + final int indexRowB = 
j*B.numCols; + + double sum = 0; + for (int indexA = idx0; indexA < idx1; indexA++) { + int k = A.nz_rows[indexA]; + sum += A.nz_values[indexA]*B.data[indexRowB + k]; + } + + C.data[i*C.numCols + j] = sum; + } + }); + } + + public static void multAddTransAB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C ) { + // C(i,j) = sum_k A(k,i) * B(j,k) + EjmlConcurrency.loopFor(0, B.numRows, j -> { + for (int i = 0; i < A.numCols; i++) { + int idx0 = A.col_idx[i]; + int idx1 = A.col_idx[i + 1]; + + final int indexRowB = j*B.numCols; + + double sum = 0; + for (int indexA = idx0; indexA < idx1; indexA++) { + int k = A.nz_rows[indexA]; + sum += A.nz_values[indexA]*B.data[indexRowB + k]; + } + + C.data[i*C.numCols + j] += sum; + } + }); + } } diff --git a/main/ejml-dsparse/test/org/ejml/sparse/csc/TestCommonOps_DSCC.java b/main/ejml-dsparse/test/org/ejml/sparse/csc/TestCommonOps_DSCC.java index de67fe756..61f4fb032 100644 --- a/main/ejml-dsparse/test/org/ejml/sparse/csc/TestCommonOps_DSCC.java +++ b/main/ejml-dsparse/test/org/ejml/sparse/csc/TestCommonOps_DSCC.java @@ -211,11 +211,11 @@ private void check_s_d_mult( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, b CommonOps_DSCC.multAddTransAB(A_t, B_t, C); CommonOps_DDRM.multAddTransAB(denseA_t, B_t, expected); } else { - CommonOps_DSCC.multAddTransA(A_t, B, C); + CommonOps_DSCC.multAddTransA(A_t, B, C, null); CommonOps_DDRM.multAddTransA(denseA_t, B, expected); } } else if (transB) { - CommonOps_DSCC.multAddTransB(A, B_t, C); + CommonOps_DSCC.multAddTransB(A, B_t, C, null); CommonOps_DDRM.multAddTransB(denseA, B_t, expected); } else { CommonOps_DSCC.multAdd(A, B, C); @@ -227,11 +227,11 @@ private void check_s_d_mult( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, b CommonOps_DSCC.multTransAB(A_t, B_t, C); CommonOps_DDRM.multTransAB(denseA_t, B_t, expected); } else { - CommonOps_DSCC.multTransA(A_t, B, C); + CommonOps_DSCC.multTransA(A_t, B, C, null); CommonOps_DDRM.multTransA(denseA_t, B, expected); } } else if 
(transB) { - CommonOps_DSCC.multTransB(A, B_t, C); + CommonOps_DSCC.multTransB(A, B_t, C, null); CommonOps_DDRM.multTransB(denseA, B_t, expected); } else { CommonOps_DSCC.mult(A, B, C); @@ -547,7 +547,7 @@ public void maxAbsCols() { int nz_a = RandomMatrices_DSCC.nonzero(rows, cols, 0.05, 0.8, rand); DMatrixSparseCSC A = RandomMatrices_DSCC.rectangle(rows, cols, nz_a, rand); - DMatrixRMaj values = new DMatrixRMaj(1,1); + DMatrixRMaj values = new DMatrixRMaj(1, 1); CommonOps_DSCC.maxAbsCols(A, values); @@ -1509,7 +1509,7 @@ public void applyColumnWise() { DMatrixSparseCSC B = CommonOps_DSCC.applyColumnIdx(A, applyFunc, null); A.createCoordinateIterator() - .forEachRemaining(entry -> assertEquals(B.get(entry.row, entry.col) ,applyFunc.apply(entry.col, entry.value))); + .forEachRemaining(entry -> assertEquals(B.get(entry.row, entry.col), applyFunc.apply(entry.col, entry.value))); } @Test diff --git a/main/ejml-dsparse/test/org/ejml/sparse/csc/TestCommonOps_MT_DSCC.java b/main/ejml-dsparse/test/org/ejml/sparse/csc/TestCommonOps_MT_DSCC.java index 904726237..877c708df 100644 --- a/main/ejml-dsparse/test/org/ejml/sparse/csc/TestCommonOps_MT_DSCC.java +++ b/main/ejml-dsparse/test/org/ejml/sparse/csc/TestCommonOps_MT_DSCC.java @@ -19,6 +19,7 @@ package org.ejml.sparse.csc; import org.ejml.UtilEjml; +import org.ejml.data.DGrowArray; import org.ejml.data.DMatrixRMaj; import org.ejml.data.DMatrixSparseCSC; import org.ejml.dense.row.CommonOps_DDRM; @@ -26,6 +27,7 @@ import org.ejml.dense.row.RandomMatrices_DDRM; import org.ejml.ops.DConvertMatrixStruct; import org.junit.jupiter.api.Test; +import pabeles.concurrency.GrowArray; import java.util.Random; @@ -39,8 +41,9 @@ class TestCommonOps_MT_DSCC { private final Random rand = new Random(234); - @Test - public void mult_s_s_shapes() { + private GrowArray growArray = new GrowArray<>(DGrowArray::new); + + @Test void mult_s_s_shapes() { // multiple trials to test more sparse structures for (int trial = 0; trial < 50; trial++) { 
check_s_s_mult( @@ -86,8 +89,7 @@ private void check_s_s_mult( DMatrixSparseCSC A, DMatrixSparseCSC B, DMatrixSpar } } - @Test - public void add_shapes() { + @Test void add_shapes() { check_add( RandomMatrices_DSCC.rectangle(5, 6, 5, rand), RandomMatrices_DSCC.rectangle(5, 6, 5, rand), @@ -143,8 +145,7 @@ private void check_add( DMatrixSparseCSC A, DMatrixSparseCSC B, DMatrixSparseCSC } } - @Test - public void mult_s_d_shapes() { + @Test void mult_s_d_shapes() { check_s_d_mult( RandomMatrices_DSCC.rectangle(5, 6, 5, rand), RandomMatrices_DDRM.rectangle(6, 4, rand), @@ -188,40 +189,34 @@ private void check_s_d_mult( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, b if (add) { if (transA) { if (transB) { - continue; -// CommonOps_DSCC.multAddTransAB(A_t, B_t, C); -// CommonOps_DDRM.multAddTransAB(denseA_t, B_t, expected); + CommonOps_MT_DSCC.multAddTransAB(A_t, B_t, C); + CommonOps_DDRM.multAddTransAB(denseA_t, B_t, expected); } else { - CommonOps_DSCC.multAddTransA(A_t, B, C); + CommonOps_MT_DSCC.multAddTransA(A_t, B, C, growArray); CommonOps_DDRM.multAddTransA(denseA_t, B, expected); } } else if (transB) { - continue; -// CommonOps_DSCC.multAddTransB(A, B_t, C); -// CommonOps_DDRM.multAddTransB(denseA, B_t, expected); + CommonOps_MT_DSCC.multAddTransB(A, B_t, C, growArray); + CommonOps_DDRM.multAddTransB(denseA, B_t, expected); } else { - continue; -// CommonOps_DSCC.multAdd(A, B, C); -// CommonOps_DDRM.multAdd(denseA, B, expected); + CommonOps_MT_DSCC.multAdd(A, B, C, growArray); + CommonOps_DDRM.multAdd(denseA, B, expected); } } else { if (transA) { if (transB) { - continue; -// CommonOps_DSCC.multTransAB(A_t, B_t, C); -// CommonOps_DDRM.multTransAB(denseA_t, B_t, expected); + CommonOps_MT_DSCC.multTransAB(A_t, B_t, C); + CommonOps_DDRM.multTransAB(denseA_t, B_t, expected); } else { - CommonOps_DSCC.multTransA(A_t, B, C); + CommonOps_MT_DSCC.multTransA(A_t, B, C, growArray); CommonOps_DDRM.multTransA(denseA_t, B, expected); } } else if (transB) { - continue; 
-// CommonOps_DSCC.multTransB(A, B_t, C); -// CommonOps_DDRM.multTransB(denseA, B_t, expected); + CommonOps_MT_DSCC.multTransB(A, B_t, C, growArray); + CommonOps_DDRM.multTransB(denseA, B_t, expected); } else { - continue; -// CommonOps_DSCC.mult(A, B, C); -// CommonOps_DDRM.mult(denseA, B, expected); + CommonOps_MT_DSCC.mult(A, B, C, growArray); + CommonOps_DDRM.mult(denseA, B, expected); } } @@ -229,9 +224,11 @@ private void check_s_d_mult( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj C, b fail("exception expected"); assertTrue(MatrixFeatures_DDRM.isIdentical(expected, C, UtilEjml.TEST_F64)); - } catch (RuntimeException ignore) { - if (!exception) + } catch (RuntimeException e) { + if (!exception) { + e.printStackTrace(); fail("no exception expected"); + } } } } diff --git a/main/ejml-dsparse/test/org/ejml/sparse/csc/misc/TestImplCommonOps_MT_DSCC.java b/main/ejml-dsparse/test/org/ejml/sparse/csc/misc/TestImplCommonOps_MT_DSCC.java index fce468a23..119eb4a5f 100644 --- a/main/ejml-dsparse/test/org/ejml/sparse/csc/misc/TestImplCommonOps_MT_DSCC.java +++ b/main/ejml-dsparse/test/org/ejml/sparse/csc/misc/TestImplCommonOps_MT_DSCC.java @@ -23,7 +23,9 @@ import org.ejml.sparse.csc.CommonOps_DSCC; import org.ejml.sparse.csc.MatrixFeatures_DSCC; import org.ejml.sparse.csc.RandomMatrices_DSCC; +import org.ejml.sparse.csc.mult.Workspace_MT_DSCC; import org.junit.jupiter.api.Test; +import pabeles.concurrency.GrowArray; import java.util.Random; @@ -35,6 +37,8 @@ class TestImplCommonOps_MT_DSCC { private final Random rand = new Random(324); + private final GrowArray listWork = new GrowArray<>(Workspace_MT_DSCC::new); + @Test void add() { double alpha = 1.5; @@ -48,7 +52,7 @@ void add() { DMatrixSparseCSC cc = c.copy(); ImplCommonOps_DSCC.add(alpha, a, beta, b, c, null, null); - ImplCommonOps_MT_DSCC.add(alpha, a, beta, b, cc, null); + ImplCommonOps_MT_DSCC.add(alpha, a, beta, b, cc, listWork); assertTrue(CommonOps_DSCC.checkStructure(cc)); 
assertTrue(MatrixFeatures_DSCC.isEqualsSort(c, cc, UtilEjml.TEST_F64)); diff --git a/main/ejml-dsparse/test/org/ejml/sparse/csc/mult/TestImplMultiplication_DSCC.java b/main/ejml-dsparse/test/org/ejml/sparse/csc/mult/TestImplMultiplication_DSCC.java index ce47cc2e6..159b8e155 100644 --- a/main/ejml-dsparse/test/org/ejml/sparse/csc/mult/TestImplMultiplication_DSCC.java +++ b/main/ejml-dsparse/test/org/ejml/sparse/csc/mult/TestImplMultiplication_DSCC.java @@ -41,6 +41,7 @@ public class TestImplMultiplication_DSCC { Random rand = new Random(234); + DGrowArray workArray = new DGrowArray(); @Test void mult_s_s() { for (int i = 0; i < 50; i++) { @@ -153,10 +154,10 @@ private void multTransA_s_d( int elementsA, boolean add ) { DMatrixRMaj dense_a = DConvertMatrixStruct.convert(a, (DMatrixRMaj)null); if (add) { - ImplMultiplication_DSCC.multAddTransA(a, b, c); + ImplMultiplication_DSCC.multAddTransA(a, b, c, workArray); CommonOps_DDRM.multAddTransA(dense_a, b, expected_c); } else { - ImplMultiplication_DSCC.multTransA(a, b, c); + ImplMultiplication_DSCC.multTransA(a, b, c, workArray); CommonOps_DDRM.multTransA(dense_a, b, expected_c); } for (int row = 0; row < c.numRows; row++) { @@ -186,10 +187,10 @@ private void multTransB_s_d( int elementsA, boolean add ) { DMatrixRMaj dense_a = DConvertMatrixStruct.convert(a, (DMatrixRMaj)null); if (add) { - ImplMultiplication_DSCC.multAddTransB(a, b, c); + ImplMultiplication_DSCC.multAddTransB(a, b, c, workArray); CommonOps_DDRM.multAddTransB(dense_a, b, expected_c); } else { - ImplMultiplication_DSCC.multTransB(a, b, c); + ImplMultiplication_DSCC.multTransB(a, b, c, workArray); CommonOps_DDRM.multTransB(dense_a, b, expected_c); } for (int row = 0; row < c.numRows; row++) { diff --git a/main/ejml-dsparse/test/org/ejml/sparse/csc/mult/TestImplMultiplication_MT_DSCC.java b/main/ejml-dsparse/test/org/ejml/sparse/csc/mult/TestImplMultiplication_MT_DSCC.java index b39ccab67..723484889 100644 --- 
a/main/ejml-dsparse/test/org/ejml/sparse/csc/mult/TestImplMultiplication_MT_DSCC.java +++ b/main/ejml-dsparse/test/org/ejml/sparse/csc/mult/TestImplMultiplication_MT_DSCC.java @@ -41,7 +41,9 @@ * @author Peter Abeles */ class TestImplMultiplication_MT_DSCC { - Random rand = new Random(234); + private final Random rand = new Random(234); + private final GrowArray workArrays = new GrowArray<>(DGrowArray::new); + private final GrowArray workSpaceMT = new GrowArray<>(Workspace_MT_DSCC::new); @Test void mult_s_s() { for (int i = 0; i < 50; i++) { @@ -63,7 +65,7 @@ private void mult_s_s( int rowsA, int colsA, int colsB ) { DMatrixSparseCSC found = expected.copy(); ImplMultiplication_DSCC.mult(a, b, expected, null, null); - ImplMultiplication_MT_DSCC.mult(a, b, found, null); + ImplMultiplication_MT_DSCC.mult(a, b, found, workSpaceMT); assertTrue(CommonOps_DSCC.checkStructure(found)); assertTrue(MatrixFeatures_DSCC.isEqualsSort(expected, found, UtilEjml.TEST_F64)); @@ -119,8 +121,8 @@ private void multTransA_s_d( int rowsA, int colsA, int colsB ) { DMatrixRMaj expected = RandomMatrices_DDRM.rectangle(rowsA, colsB, -1, 1, rand); DMatrixRMaj found = expected.copy(); - ImplMultiplication_DSCC.multTransA(a, b, expected); - ImplMultiplication_MT_DSCC.multTransA(a, b, found); + ImplMultiplication_DSCC.multTransA(a, b, expected, workArrays.grow()); + ImplMultiplication_MT_DSCC.multTransA(a, b, found, workArrays); assertTrue(MatrixFeatures_DDRM.isEquals(expected, found, UtilEjml.TEST_F64)); } @@ -140,8 +142,8 @@ private void multAddTransA_s_d( int rowsA, int colsA, int colsB ) { DMatrixRMaj expected = RandomMatrices_DDRM.rectangle(rowsA, colsB, -1, 1, rand); DMatrixRMaj found = expected.copy(); - ImplMultiplication_DSCC.multAddTransA(a, b, expected); - ImplMultiplication_MT_DSCC.multAddTransA(a, b, found); + ImplMultiplication_DSCC.multAddTransA(a, b, expected, workArrays.grow()); + ImplMultiplication_MT_DSCC.multAddTransA(a, b, found, workArrays); 
assertTrue(MatrixFeatures_DDRM.isEquals(expected, found, UtilEjml.TEST_F64)); } @@ -152,9 +154,9 @@ private void multAddTransA_s_d( int rowsA, int colsA, int colsB ) { multTransB_s_d(15, false); multTransB_s_d(4, false); -// multTransB_s_d(24, true); -// multTransB_s_d(15, true); -// multTransB_s_d(4, true); + multTransB_s_d(24, true); + multTransB_s_d(15, true); + multTransB_s_d(4, true); } } @@ -168,7 +170,7 @@ private void multTransB_s_d( int elementsA, boolean add ) { GrowArray work = new GrowArray<>(DGrowArray::new); if (add) { -// ImplSparseSparseMult_MT_DSCC.multAddTransB(a, b, c); + ImplMultiplication_MT_DSCC.multAddTransB(a, b, c, work); CommonOps_DDRM.multAddTransB(dense_a, b, expected_c); } else { ImplMultiplication_MT_DSCC.multTransB(a, b, c, false, work); @@ -180,4 +182,38 @@ private void multTransB_s_d( int elementsA, boolean add ) { } } } + + @Test void multTransAB_s_d() { + for (int i = 0; i < 10; i++) { + multTransAB_s_d(24, false); + multTransAB_s_d(15, false); + multTransAB_s_d(4, false); + + multTransAB_s_d(24, true); + multTransAB_s_d(15, true); + multTransAB_s_d(4, true); + } + } + + private void multTransAB_s_d( int elementsA, boolean add ) { + DMatrixSparseCSC a = RandomMatrices_DSCC.rectangle(6, 4, elementsA, -1, 1, rand); + DMatrixRMaj b = RandomMatrices_DDRM.rectangle(5, 6, -1, 1, rand); + DMatrixRMaj c = RandomMatrices_DDRM.rectangle(4, 5, -1, 1, rand); + DMatrixRMaj expected_c = c.copy(); + DMatrixRMaj dense_a = DConvertMatrixStruct.convert(a, (DMatrixRMaj)null); + + if (add) { + ImplMultiplication_MT_DSCC.multAddTransAB(a, b, c); + CommonOps_DDRM.multAddTransAB(dense_a, b, expected_c); + } else { + ImplMultiplication_MT_DSCC.multTransAB(a, b, c); + CommonOps_DDRM.multTransAB(dense_a, b, expected_c); + } + + for (int row = 0; row < c.numRows; row++) { + for (int col = 0; col < c.numCols; col++) { + assertEquals(expected_c.get(row, col), c.get(row, col), UtilEjml.TEST_F64, row + " " + col); + } + } + } } \ No newline at end of file diff 
--git a/main/ejml-simple/src/org/ejml/simple/ops/SimpleOperations_DSCC.java b/main/ejml-simple/src/org/ejml/simple/ops/SimpleOperations_DSCC.java index 1dbc16e65..179f60267 100644 --- a/main/ejml-simple/src/org/ejml/simple/ops/SimpleOperations_DSCC.java +++ b/main/ejml-simple/src/org/ejml/simple/ops/SimpleOperations_DSCC.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2009-2020, Peter Abeles. All Rights Reserved. + * Copyright (c) 2020, Peter Abeles. All Rights Reserved. * * This file is part of Efficient Java Matrix Library (EJML). * @@ -94,7 +94,7 @@ public void extractDiag( DMatrixSparseCSC input, DMatrixRMaj output ) { @Override public void multTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj output ) { - CommonOps_DSCC.multTransA(A, B, output); + CommonOps_DSCC.multTransA(A, B, output, null); } @Override