diff --git a/mllib-dal/src/main/scala/com/intel/oap/mllib/classification/NaiveBayesDALImpl.scala b/mllib-dal/src/main/scala/com/intel/oap/mllib/classification/NaiveBayesDALImpl.scala index 5a274cf7f..b497ba1b5 100644 --- a/mllib-dal/src/main/scala/com/intel/oap/mllib/classification/NaiveBayesDALImpl.scala +++ b/mllib-dal/src/main/scala/com/intel/oap/mllib/classification/NaiveBayesDALImpl.scala @@ -89,7 +89,7 @@ class NaiveBayesDALImpl(val uid: String, OneCCL.cleanup() ret - }.collect() + }.barrier().mapPartitions(iter => iter).collect() // Make sure there is only one result from rank 0 assert(results.length == 1) diff --git a/mllib-dal/src/main/scala/com/intel/oap/mllib/classification/RandomForestClassifierDALImpl.scala b/mllib-dal/src/main/scala/com/intel/oap/mllib/classification/RandomForestClassifierDALImpl.scala index cb3ef3e52..772b1b281 100644 --- a/mllib-dal/src/main/scala/com/intel/oap/mllib/classification/RandomForestClassifierDALImpl.scala +++ b/mllib-dal/src/main/scala/com/intel/oap/mllib/classification/RandomForestClassifierDALImpl.scala @@ -124,7 +124,7 @@ class RandomForestClassifierDALImpl(val uid: String, } OneCCL.cleanup() ret - }.collect() + }.barrier().mapPartitions(iter => iter).collect() rfcTimer.record("Training") rfcTimer.print() diff --git a/mllib-dal/src/main/scala/com/intel/oap/mllib/clustering/KMeansDALImpl.scala b/mllib-dal/src/main/scala/com/intel/oap/mllib/clustering/KMeansDALImpl.scala index 3a23a1d65..6f4b55547 100644 --- a/mllib-dal/src/main/scala/com/intel/oap/mllib/clustering/KMeansDALImpl.scala +++ b/mllib-dal/src/main/scala/com/intel/oap/mllib/clustering/KMeansDALImpl.scala @@ -95,16 +95,18 @@ class KMeansDALImpl(var nClusters: Int, } OneCCL.cleanup() ret - }.collect() + } + results.count() + val barrierRDD = results.barrier().mapPartitions(iter => iter).collect() // Make sure there is only one result from rank 0 - assert(results.length == 1) + assert(barrierRDD.length == 1) kmeansTimer.record("Training") kmeansTimer.print() - val centerVectors = results(0)._1 - val totalCost = results(0)._2 - val iterationNum = results(0)._3 + val centerVectors = barrierRDD(0)._1 + val totalCost = barrierRDD(0)._2 + val iterationNum = barrierRDD(0)._3 if (iterationNum == maxIterations) { logInfo(s"KMeans reached the max number of iterations: $maxIterations.") diff --git a/mllib-dal/src/main/scala/com/intel/oap/mllib/feature/PCADALImpl.scala b/mllib-dal/src/main/scala/com/intel/oap/mllib/feature/PCADALImpl.scala index 4f0807abd..2133e64e4 100644 --- a/mllib-dal/src/main/scala/com/intel/oap/mllib/feature/PCADALImpl.scala +++ b/mllib-dal/src/main/scala/com/intel/oap/mllib/feature/PCADALImpl.scala @@ -45,6 +45,7 @@ class PCADALImpl(val k: Int, def train(data: RDD[Vector]): PCADALModel = { val pcaTimer = new Utils.AlgoTimeMetrics("PCA") val normalizedData = normalizeData(data) + val sparkContext = normalizedData.sparkContext val useDevice = sparkContext.getConf.get("spark.oap.mllib.device", Utils.DefaultComputeDevice) val computeDevice = Common.ComputeDevice.getDeviceByName(useDevice) @@ -104,7 +105,7 @@ class PCADALImpl(val k: Int, } OneCCL.cleanup() ret - }.collect() + }.barrier().mapPartitions(iter => iter).collect() pcaTimer.record("Training") pcaTimer.print() diff --git a/mllib-dal/src/main/scala/com/intel/oap/mllib/recommendation/ALSDALImpl.scala b/mllib-dal/src/main/scala/com/intel/oap/mllib/recommendation/ALSDALImpl.scala index f8c781caf..78783c566 100644 --- a/mllib-dal/src/main/scala/com/intel/oap/mllib/recommendation/ALSDALImpl.scala +++ b/mllib-dal/src/main/scala/com/intel/oap/mllib/recommendation/ALSDALImpl.scala @@ -106,7 +106,7 @@ class ALSDALImpl[@specialized(Int, Long) ID: ClassTag]( data: RDD[Rating[ID]], result ) Iterator(result) - }.cache() + }.cache().barrier().mapPartitions(iter => iter) val usersFactorsRDD = results .mapPartitionsWithIndex { (index: Int, partiton: Iterator[ALSResult]) => @@ -127,7 +127,7 @@ class ALSDALImpl[@specialized(Int, Long) ID: ClassTag]( data: RDD[Rating[ID]], }.toIterator } ret - }.setName("userFactors").cache() + }.setName("userFactors").cache().barrier().mapPartitions(iter => iter) val itemsFactorsRDD = results .mapPartitionsWithIndex { (index: Int, partiton: Iterator[ALSResult]) => diff --git a/mllib-dal/src/main/scala/com/intel/oap/mllib/regression/LinearRegressionDALImpl.scala b/mllib-dal/src/main/scala/com/intel/oap/mllib/regression/LinearRegressionDALImpl.scala index ee93892b2..067f0e3dd 100644 --- a/mllib-dal/src/main/scala/com/intel/oap/mllib/regression/LinearRegressionDALImpl.scala +++ b/mllib-dal/src/main/scala/com/intel/oap/mllib/regression/LinearRegressionDALImpl.scala @@ -151,7 +151,7 @@ class LinearRegressionDALImpl( val fitIntercept: Boolean, } OneCCL.cleanup() ret - }.collect() + }.barrier().mapPartitions(iter => iter).collect() // Make sure there is only one result from rank 0 assert(results.length == 1) diff --git a/mllib-dal/src/main/scala/com/intel/oap/mllib/regression/RandomForestRegressorDALImpl.scala b/mllib-dal/src/main/scala/com/intel/oap/mllib/regression/RandomForestRegressorDALImpl.scala index a4e9a0f78..1de452cf8 100644 --- a/mllib-dal/src/main/scala/com/intel/oap/mllib/regression/RandomForestRegressorDALImpl.scala +++ b/mllib-dal/src/main/scala/com/intel/oap/mllib/regression/RandomForestRegressorDALImpl.scala @@ -126,7 +126,7 @@ class RandomForestRegressorDALImpl(val uid: String, } ret - }.collect() + }.barrier().mapPartitions(iter => iter).collect() rfrTimer.record("Training") rfrTimer.print() diff --git a/mllib-dal/src/main/scala/com/intel/oap/mllib/stat/CorrelationDALImpl.scala b/mllib-dal/src/main/scala/com/intel/oap/mllib/stat/CorrelationDALImpl.scala index 20362b896..96ea7a61b 100644 --- a/mllib-dal/src/main/scala/com/intel/oap/mllib/stat/CorrelationDALImpl.scala +++ b/mllib-dal/src/main/scala/com/intel/oap/mllib/stat/CorrelationDALImpl.scala @@ -94,7 +94,7 @@ class CorrelationDALImpl( } OneCCL.cleanup() ret - }.collect() + }.barrier().mapPartitions(iter => iter).collect() corTimer.record("Training") corTimer.print() diff --git a/mllib-dal/src/main/scala/com/intel/oap/mllib/stat/SummarizerDALImpl.scala b/mllib-dal/src/main/scala/com/intel/oap/mllib/stat/SummarizerDALImpl.scala index 3f108364c..0b0d97520 100644 --- a/mllib-dal/src/main/scala/com/intel/oap/mllib/stat/SummarizerDALImpl.scala +++ b/mllib-dal/src/main/scala/com/intel/oap/mllib/stat/SummarizerDALImpl.scala @@ -119,7 +119,7 @@ class SummarizerDALImpl(val executorNum: Int, } OneCCL.cleanup() ret - }.collect() + }.barrier().mapPartitions(iter => iter).collect() sumTimer.record("Training") sumTimer.print()