From e24a250e40cbb59e25656a2749deef517469c76d Mon Sep 17 00:00:00 2001 From: minmingzhu Date: Fri, 16 Aug 2024 11:43:30 +0800 Subject: [PATCH] fix lr bugs --- .../ml/regression/spark313/LinearRegression.scala | 12 +++++++++++- .../ml/regression/spark333/LinearRegression.scala | 11 ++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark313/LinearRegression.scala b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark313/LinearRegression.scala index 6a8d1051b..84737560c 100644 --- a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark313/LinearRegression.scala +++ b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark313/LinearRegression.scala @@ -455,7 +455,14 @@ class LinearRegression @Since("1.3") (@Since("1.3.0") override val uid: String) private def trainWithNormal( dataset: Dataset[_], instr: Instrumentation): LinearRegressionModel = { - val paramSupported = ($(regParam) == 0) && (!isDefined(weightCol) || getWeightCol.isEmpty) + val handlePersistence = (dataset.storageLevel == StorageLevel.NONE) + + if (handlePersistence) { + dataset.persist(StorageLevel.MEMORY_AND_DISK) + dataset.count() + } + + val paramSupported = ($(regParam) == 0) || ($(regParam) != 0 && $(elasticNetParam) == 0) val sparkContext = dataset.sparkSession.sparkContext val useDevice = sparkContext.getConf.get("spark.oap.mllib.device", Utils.DefaultComputeDevice) val isPlatformSupported = Utils.checkClusterPlatformCompatibility( @@ -486,6 +493,9 @@ class LinearRegression @Since("1.3") (@Since("1.3.0") override val uid: String) model.diagInvAtWA.toArray, model.objectiveHistory) + if (handlePersistence) { + dataset.unpersist() + } return lrModel.setSummary(Some(trainingSummary)) } else { // For low dimensional data, WeightedLeastSquares is more efficient since the diff --git a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark333/LinearRegression.scala b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark333/LinearRegression.scala index a921dfbfc..ef71ee0d8 100644 --- a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark333/LinearRegression.scala +++ b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark333/LinearRegression.scala @@ -454,7 +454,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String private def trainWithNormal( dataset: Dataset[_], instr: Instrumentation): LinearRegressionModel = { - val paramSupported = ($(regParam) == 0) && (!isDefined(weightCol) || getWeightCol.isEmpty) + val handlePersistence = (dataset.storageLevel == StorageLevel.NONE) + + if (handlePersistence) { + dataset.persist(StorageLevel.MEMORY_AND_DISK) + dataset.count() + } + val paramSupported = ($(regParam) == 0) || ($(regParam) != 0 && $(elasticNetParam) == 0) val sparkContext = dataset.sparkSession.sparkContext val useDevice = sparkContext.getConf.get("spark.oap.mllib.device", Utils.DefaultComputeDevice) val isPlatformSupported = Utils.checkClusterPlatformCompatibility( @@ -485,6 +491,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model.diagInvAtWA.toArray, model.objectiveHistory) + if (handlePersistence) { + dataset.unpersist() + } return lrModel.setSummary(Some(trainingSummary)) } else { // For low dimensional data, WeightedLeastSquares is more efficient since the