-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
138 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
39 changes: 39 additions & 0 deletions
39
scala2/src/main/scala/jurisk/math/GaussianElimination.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
package jurisk.math | ||
|
||
// https://en.wikipedia.org/wiki/Gaussian_elimination | ||
object GaussianElimination { | ||
def solve(A: Array[Array[Double]], b: Array[Double]): Array[Double] = { | ||
val n = A.length | ||
val augmented = A.zip(b).map { case (row, bi) => row :+ bi } | ||
|
||
for (col <- 0 until n) { | ||
// Find pivot row | ||
val pivotRow = (col until n).maxBy(row => math.abs(augmented(row)(col))) | ||
val temp = augmented(col) | ||
augmented(col) = augmented(pivotRow) | ||
augmented(pivotRow) = temp | ||
|
||
// Make leading coefficient of pivot row 1 | ||
val pivotElement = augmented(col)(col) | ||
for (j <- col until n + 1) | ||
augmented(col)(j) /= pivotElement | ||
|
||
// Eliminate below pivot | ||
for (i <- col + 1 until n) { | ||
val factor = augmented(i)(col) | ||
for (j <- col until n + 1) | ||
augmented(i)(j) -= factor * augmented(col)(j) | ||
} | ||
} | ||
|
||
// Back substitution | ||
val x = new Array[Double](n) | ||
for (i <- n - 1 to 0 by -1) { | ||
x(i) = augmented(i)(n) | ||
for (j <- i + 1 until n) | ||
x(i) -= augmented(i)(j) * x(j) | ||
} | ||
|
||
x | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 57 additions & 0 deletions
57
scala2/src/test/scala/jurisk/math/GaussianEliminationSpec.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
package jurisk.math | ||
|
||
import jurisk.math.GaussianElimination.solve | ||
import org.scalatest.freespec.AnyFreeSpec | ||
import org.scalatest.matchers.should.Matchers._ | ||
|
||
class GaussianEliminationSpec extends AnyFreeSpec { | ||
private def compare( | ||
solution: Array[Double], | ||
expected: Array[Double], | ||
): Unit = { | ||
val Eps = 1e-5 | ||
|
||
solution.length shouldEqual expected.length | ||
for ((v, i) <- solution.zipWithIndex) | ||
assert( | ||
math.abs(v - expected(i)) < Eps, | ||
s"Value at index $i should be close to ${expected(i)}", | ||
) | ||
} | ||
|
||
"GaussianElimination" - { | ||
"test 1" in { | ||
// 2x + y -z = 8 | ||
// -3x -y + 2z = -11 | ||
// -2x +y + 2z = -3 | ||
val A = Array( | ||
Array(2.0, 1.0, -1.0), | ||
Array(-3.0, -1.0, 2.0), | ||
Array(-2.0, 1.0, 2.0), | ||
) | ||
val b = Array(8.0, -11.0, -3.0) | ||
val expected = Array(2.0, 3.0, -1.0) | ||
val solution = solve(A, b) | ||
|
||
compare(solution, expected) | ||
} | ||
|
||
// https://en.wikipedia.org/wiki/Gaussian_elimination#Example_of_the_algorithm | ||
"test from Wikipedia" in { | ||
// 2x + y - z = 8 | ||
// -3x -y + 2z = -11 | ||
// -2x + y + 2z = -3 | ||
val A = Array( | ||
Array(2.0, 1.0, -1.0), | ||
Array(-3.0, -1.0, 2.0), | ||
Array(-2.0, 1.0, 2.0), | ||
) | ||
|
||
val b = Array(8.0, -11.0, -3.0) | ||
val expected = Array(2.0, 3.0, -1.0) | ||
val solution = solve(A, b) | ||
|
||
compare(solution, expected) | ||
} | ||
} | ||
} |