Skip to content

Commit

Permalink
CommonOps_DSCC
Browse files Browse the repository at this point in the history
- Speed up more mult variants
  • Loading branch information
lessthanoptimal committed Nov 4, 2020
1 parent 09db214 commit b838693
Show file tree
Hide file tree
Showing 14 changed files with 343 additions and 165 deletions.
3 changes: 2 additions & 1 deletion change.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ Date format: year/month/day
* Removed the functions below since they had a runtime complexity of O(N^2) relative to matrix size instead of O(N)
- multTransA(S,S,S), multTransB(S,S,S), innerProductLower(S,S,S)
- Thanks Florentin Dörre for first noticing the performance issue
* Speed up multTransAB(S,D,D), multTransA(S,D,D), multTransB(S,D,D) by a large margin
- DMatrixSparseCSC
* If sorted a binary search is used to lookup rows. Thanks Florentin Dörre.
* If sorted, a binary search is used to lookup rows. Thanks Florentin Dörre.
- ReadMatrixCsv
* Thanks DEDZTBH for fixing an indexing error when reading complex data types
- Added Concurrent Algorithms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package org.ejml.sparse.csc;

import org.ejml.data.DGrowArray;
import org.ejml.data.DMatrixRMaj;
import org.ejml.data.DMatrixSparseCSC;
import org.ejml.dense.row.RandomMatrices_DDRM;
Expand Down Expand Up @@ -47,27 +48,30 @@ public class BenchmarkMatrixMultDense_DSCC {
@Param({"100000"})
private int elementCount;

DMatrixSparseCSC A;
DMatrixSparseCSC A,A_small;
DMatrixRMaj B = new DMatrixRMaj(1, 1);
DMatrixRMaj C = new DMatrixRMaj(1, 1);

DGrowArray work = new DGrowArray();

@Setup
public void setup() {
Random rand = new Random(2345);
A = RandomMatrices_DSCC.rectangle(dimension, dimension, elementCount, rand);
A_small = RandomMatrices_DSCC.rectangle(dimension/4, dimension/4, elementCount/4, rand);
B = RandomMatrices_DDRM.rectangle(dimension, dimension, -1, 1, rand);
C = B.create(dimension, dimension);
}

@Benchmark public void mult() { CommonOps_DSCC.mult(A, B, C); }
@Benchmark public void multAdd() { CommonOps_DSCC.multAdd(A, B, C); }
@Benchmark public void multTransA() { CommonOps_DSCC.multTransA(A, B, C); }
@Benchmark public void multAddTransA() { CommonOps_DSCC.multAddTransA(A, B, C); }
@Benchmark public void multTransB() { CommonOps_DSCC.multTransB(A, B, C); }
@Benchmark public void multAddTransB() { CommonOps_DSCC.multAddTransB(A, B, C); }
@Benchmark public void multTransA() { CommonOps_DSCC.multTransA(A, B, C, work); }
@Benchmark public void multAddTransA() { CommonOps_DSCC.multAddTransA(A, B, C, work); }
@Benchmark public void multTransB() { CommonOps_DSCC.multTransB(A, B, C, work); }
@Benchmark public void multAddTransB() { CommonOps_DSCC.multAddTransB(A, B, C, work); }
@Benchmark public void multTransAB() { CommonOps_DSCC.multTransAB(A, B, C); }
@Benchmark public void multAddTransAB() { CommonOps_DSCC.multAddTransAB(A, B, C); }
@Benchmark public void invert() { CommonOps_DSCC.invert(A, C); }
@Benchmark public void invert() { CommonOps_DSCC.invert(A_small, C); }

public static void main( String[] args ) throws RunnerException {
Options opt = new OptionsBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public class BenchmarkMatrixMultDense_MT_DSCC {
@Param({"100000"})
private int elementCount;

DMatrixSparseCSC A;
DMatrixSparseCSC A,A_small;
DMatrixRMaj B = new DMatrixRMaj(1, 1);
DMatrixRMaj C = new DMatrixRMaj(1, 1);

Expand All @@ -59,19 +59,20 @@ public class BenchmarkMatrixMultDense_MT_DSCC {
public void setup() {
Random rand = new Random(2345);
A = RandomMatrices_DSCC.rectangle(dimension, dimension, elementCount, rand);
A_small = RandomMatrices_DSCC.rectangle(dimension/4, dimension/4, elementCount/4, rand);
B = RandomMatrices_DDRM.rectangle(dimension, dimension, -1, 1, rand);
C = B.create(dimension, dimension);
}

@Benchmark public void mult() { CommonOps_MT_DSCC.mult(A, B, C, work); }
@Benchmark public void multAdd() { CommonOps_MT_DSCC.multAdd(A, B, C, work); }
@Benchmark public void multTransA() { CommonOps_MT_DSCC.multTransA(A, B, C); }
@Benchmark public void multAddTransA() { CommonOps_MT_DSCC.multAddTransA(A, B, C); }
@Benchmark public void multTransA() { CommonOps_MT_DSCC.multTransA(A, B, C, work); }
@Benchmark public void multAddTransA() { CommonOps_MT_DSCC.multAddTransA(A, B, C, work); }
@Benchmark public void multTransB() { CommonOps_MT_DSCC.multTransB(A, B, C, work); }
@Benchmark public void multAddTransB() { CommonOps_MT_DSCC.multAddTransB(A, B, C, work); }
// @Benchmark public void multTransAB() { CommonOps_MT_DSCC.multTransAB(A, B, C); }
// @Benchmark public void multAddTransAB() { CommonOps_MT_DSCC.multAddTransAB(A, B, C); }
// @Benchmark public void invert() { CommonOps_MT_DSCC.invert(A, C); }
@Benchmark public void multTransAB() { CommonOps_MT_DSCC.multTransAB(A, B, C); }
@Benchmark public void multAddTransAB() { CommonOps_MT_DSCC.multAddTransAB(A, B, C); }
// @Benchmark public void invert() { CommonOps_MT_DSCC.invert(A_small, C); }

public static void main( String[] args ) throws RunnerException {
Options opt = new OptionsBuilder()
Expand Down
32 changes: 24 additions & 8 deletions main/ejml-dsparse/src/org/ejml/sparse/csc/CommonOps_DSCC.java
Original file line number Diff line number Diff line change
Expand Up @@ -188,27 +188,35 @@ public static void multAdd( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outpu
* @param B Dense Matrix
* @param outputC Dense Matrix
*/
public static DMatrixRMaj multTransA( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC ) {
public static DMatrixRMaj multTransA( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC,
@Nullable DGrowArray work ) {
if (A.numRows != B.numRows)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));

if (work == null)
work = new DGrowArray();

outputC = reshapeOrDeclare(outputC, A.numCols, B.numCols);

ImplMultiplication_DSCC.multTransA(A, B, outputC);
ImplMultiplication_DSCC.multTransA(A, B, outputC, work);

return outputC;
}

/**
* <p>C = C + A<sup>T</sup>*B</p>
*/
public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC ) {
public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC,
@Nullable DGrowArray work ) {
if (A.numRows != B.numRows)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
if (A.numCols != outputC.numRows || B.numCols != outputC.numCols)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC));

ImplMultiplication_DSCC.multAddTransA(A, B, outputC);
if (work == null)
work = new DGrowArray();

ImplMultiplication_DSCC.multAddTransA(A, B, outputC, work);
}

/**
Expand All @@ -218,26 +226,34 @@ public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj
* @param B Dense Matrix
* @param outputC Dense Matrix
*/
public static DMatrixRMaj multTransB( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC ) {
public static DMatrixRMaj multTransB( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC,
@Nullable DGrowArray work ) {
if (A.numCols != B.numCols)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
outputC = reshapeOrDeclare(outputC, A.numRows, B.numRows);

ImplMultiplication_DSCC.multTransB(A, B, outputC);
if (work == null)
work = new DGrowArray();

ImplMultiplication_DSCC.multTransB(A, B, outputC, work);

return outputC;
}

/**
* <p>C = C + A*B<sup>T</sup></p>
*/
public static void multAddTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC ) {
public static void multAddTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC,
@Nullable DGrowArray work ) {
if (A.numCols != B.numCols)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
if (A.numRows != outputC.numRows || B.numRows != outputC.numCols)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC));

ImplMultiplication_DSCC.multAddTransB(A, B, outputC);
if (work == null)
work = new DGrowArray();

ImplMultiplication_DSCC.multAddTransB(A, B, outputC, work);
}

/**
Expand Down
78 changes: 66 additions & 12 deletions main/ejml-dsparse/src/org/ejml/sparse/csc/CommonOps_MT_DSCC.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ public static DMatrixSparseCSC mult( DMatrixSparseCSC A, DMatrixSparseCSC B, @Nu
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
outputC = reshapeOrDeclare(outputC, A, A.numRows, B.numCols);

if (listWork == null)
listWork = new GrowArray<>(Workspace_MT_DSCC::new);

ImplMultiplication_MT_DSCC.mult(A, B, outputC, listWork);

return outputC;
Expand All @@ -80,6 +83,9 @@ public static DMatrixSparseCSC add( double alpha, DMatrixSparseCSC A, double bet
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
outputC = reshapeOrDeclare(outputC, A, A.numRows, A.numCols);

if (listWork == null)
listWork = new GrowArray<>(Workspace_MT_DSCC::new);

ImplCommonOps_MT_DSCC.add(alpha, A, beta, B, outputC, listWork);

return outputC;
Expand All @@ -93,12 +99,14 @@ public static DMatrixSparseCSC add( double alpha, DMatrixSparseCSC A, double bet
* @param outputC Dense Matrix
*/
public static DMatrixRMaj mult( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC,
@Nullable GrowArray<DGrowArray> listWork ) {
@Nullable GrowArray<DGrowArray> workArrays ) {
if (A.numCols != B.numRows)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
outputC = reshapeOrDeclare(outputC, A.numRows, B.numCols);
if (workArrays == null)
workArrays = new GrowArray<>(DGrowArray::new);

ImplMultiplication_MT_DSCC.mult(A, B, outputC, listWork);
ImplMultiplication_MT_DSCC.mult(A, B, outputC, workArrays);

return outputC;
}
Expand All @@ -107,13 +115,16 @@ public static DMatrixRMaj mult( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMa
* <p>C = C + A<sup>T</sup>*B</p>
*/
public static void multAdd( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC,
@Nullable GrowArray<DGrowArray> listWork ) {
@Nullable GrowArray<DGrowArray> workArrays ) {
if (A.numCols != B.numRows)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
if (A.numRows != outputC.numRows || B.numCols != outputC.numCols)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC));

ImplMultiplication_MT_DSCC.multAdd(A, B, outputC, listWork);
if (workArrays == null)
workArrays = new GrowArray<>(DGrowArray::new);

ImplMultiplication_MT_DSCC.multAdd(A, B, outputC, workArrays);
}

/**
Expand All @@ -123,27 +134,35 @@ public static void multAdd( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outpu
* @param B Dense Matrix
* @param outputC Dense Matrix
*/
public static DMatrixRMaj multTransA( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC ) {
public static DMatrixRMaj multTransA( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC,
@Nullable GrowArray<DGrowArray> workArray ) {
if (A.numRows != B.numRows)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));

outputC = reshapeOrDeclare(outputC, A.numCols, B.numCols);

ImplMultiplication_MT_DSCC.multTransA(A, B, outputC);
if (workArray == null)
workArray = new GrowArray<>(DGrowArray::new);

ImplMultiplication_MT_DSCC.multTransA(A, B, outputC, workArray);

return outputC;
}

/**
* <p>C = C + A<sup>T</sup>*B</p>
*/
public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC ) {
public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC,
@Nullable GrowArray<DGrowArray> workArray ) {
if (A.numRows != B.numRows)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
if (A.numCols != outputC.numRows || B.numCols != outputC.numCols)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC));

ImplMultiplication_MT_DSCC.multAddTransA(A, B, outputC);
if (workArray == null)
workArray = new GrowArray<>(DGrowArray::new);

ImplMultiplication_MT_DSCC.multAddTransA(A, B, outputC, workArray);
}

/**
Expand All @@ -154,12 +173,15 @@ public static void multAddTransA( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj
* @param outputC Dense Matrix
*/
public static DMatrixRMaj multTransB( DMatrixSparseCSC A, DMatrixRMaj B, @Nullable DMatrixRMaj outputC,
@Nullable GrowArray<DGrowArray> listWork ) {
@Nullable GrowArray<DGrowArray> workArrays ) {
if (A.numCols != B.numCols)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
outputC = reshapeOrDeclare(outputC, A.numRows, B.numRows);

ImplMultiplication_MT_DSCC.multTransB(A, B, outputC, listWork);
if (workArrays == null)
workArrays = new GrowArray<>(DGrowArray::new);

ImplMultiplication_MT_DSCC.multTransB(A, B, outputC, workArrays);

return outputC;
}
Expand All @@ -168,12 +190,44 @@ public static DMatrixRMaj multTransB( DMatrixSparseCSC A, DMatrixRMaj B, @Nullab
* <p>C = C + A*B<sup>T</sup></p>
*/
public static void multAddTransB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC,
@Nullable GrowArray<DGrowArray> listWork ) {
@Nullable GrowArray<DGrowArray> workArrays ) {
if (A.numCols != B.numCols)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
if (A.numRows != outputC.numRows || B.numRows != outputC.numCols)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC));

ImplMultiplication_MT_DSCC.multAddTransB(A, B, outputC, listWork);
if (workArrays == null)
workArrays = new GrowArray<>(DGrowArray::new);

ImplMultiplication_MT_DSCC.multAddTransB(A, B, outputC, workArrays);
}

/**
* Performs matrix multiplication. C = A<sup>T</sup>*B<sup>T</sup>
*
* @param A Matrix
* @param B Dense Matrix
* @param outputC Dense Matrix
*/
public static DMatrixRMaj multTransAB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC ) {
if (A.numRows != B.numCols)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
outputC = reshapeOrDeclare(outputC, A.numCols, B.numRows);

ImplMultiplication_MT_DSCC.multTransAB(A, B, outputC);

return outputC;
}

/**
* <p>C = C + A<sup>T</sup>*B<sup>T</sup></p>
*/
public static void multAddTransAB( DMatrixSparseCSC A, DMatrixRMaj B, DMatrixRMaj outputC ) {
if (A.numRows != B.numCols)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B));
if (A.numCols != outputC.numRows || B.numRows != outputC.numCols)
throw new MatrixDimensionException("Inconsistent matrix shapes. " + stringShapes(A, B, outputC));

ImplMultiplication_MT_DSCC.multAddTransAB(A, B, outputC);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.ejml.concurrency.EjmlConcurrency;
import org.ejml.data.DMatrixSparseCSC;
import org.ejml.sparse.csc.mult.Workspace_MT_DSCC;
import org.jetbrains.annotations.Nullable;
import pabeles.concurrency.GrowArray;

import static org.ejml.UtilEjml.adjust;
Expand All @@ -46,10 +45,7 @@ public class ImplCommonOps_MT_DSCC {
* @param listWork (Optional) Storage for internal workspace. Can be null.
*/
public static void add( double alpha, DMatrixSparseCSC A, double beta, DMatrixSparseCSC B, DMatrixSparseCSC C,
@Nullable GrowArray<Workspace_MT_DSCC> listWork ) {
if (listWork == null)
listWork = new GrowArray<>(Workspace_MT_DSCC::new);

GrowArray<Workspace_MT_DSCC> listWork ) {
// Break the problem up into blocks of columns and process them independently
EjmlConcurrency.loopBlocks(0, A.numCols, listWork, ( workspace, col0, col1 ) -> {
DMatrixSparseCSC workC = workspace.mat;
Expand Down
Loading

0 comments on commit b838693

Please sign in to comment.