Skip to content

Commit

Permalink
fix lr bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
minmingzhu committed Aug 16, 2024
1 parent 5d08ac0 commit e24a250
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,14 @@ class LinearRegression @Since("1.3") (@Since("1.3.0") override val uid: String)
private def trainWithNormal(
dataset: Dataset[_],
instr: Instrumentation): LinearRegressionModel = {
val paramSupported = ($(regParam) == 0) && (!isDefined(weightCol) || getWeightCol.isEmpty)
val handlePersistence = (dataset.storageLevel == StorageLevel.NONE)

if (handlePersistence) {
dataset.persist(StorageLevel.MEMORY_AND_DISK)
dataset.count()
}

val paramSupported = ($(regParam) == 0) || ($(regParam) != 0 && $(elasticNetParam) == 0)
val sparkContext = dataset.sparkSession.sparkContext
val useDevice = sparkContext.getConf.get("spark.oap.mllib.device", Utils.DefaultComputeDevice)
val isPlatformSupported = Utils.checkClusterPlatformCompatibility(
Expand Down Expand Up @@ -486,6 +493,9 @@ class LinearRegression @Since("1.3") (@Since("1.3.0") override val uid: String)
model.diagInvAtWA.toArray,
model.objectiveHistory)

if (handlePersistence) {
dataset.unpersist()
}
return lrModel.setSummary(Some(trainingSummary))
} else {
// For low dimensional data, WeightedLeastSquares is more efficient since the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
private def trainWithNormal(
dataset: Dataset[_],
instr: Instrumentation): LinearRegressionModel = {
val paramSupported = ($(regParam) == 0) && (!isDefined(weightCol) || getWeightCol.isEmpty)
val handlePersistence = (dataset.storageLevel == StorageLevel.NONE)

if (handlePersistence) {
dataset.persist(StorageLevel.MEMORY_AND_DISK)
dataset.count()
}
val paramSupported = ($(regParam) == 0) || ($(regParam) != 0 && $(elasticNetParam) == 0)
val sparkContext = dataset.sparkSession.sparkContext
val useDevice = sparkContext.getConf.get("spark.oap.mllib.device", Utils.DefaultComputeDevice)
val isPlatformSupported = Utils.checkClusterPlatformCompatibility(
Expand Down Expand Up @@ -485,6 +491,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
model.diagInvAtWA.toArray,
model.objectiveHistory)

if (handlePersistence) {
dataset.unpersist()
}
return lrModel.setSummary(Some(trainingSummary))
} else {
// For low dimensional data, WeightedLeastSquares is more efficient since the
Expand Down

0 comments on commit e24a250

Please sign in to comment.