diff --git a/src/Regressors/Ridge.php b/src/Regressors/Ridge.php index 329b3935c..e351e4736 100644 --- a/src/Regressors/Ridge.php +++ b/src/Regressors/Ridge.php @@ -21,6 +21,7 @@ use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator; use Rubix\ML\Exceptions\InvalidArgumentException; use Rubix\ML\Exceptions\RuntimeException; +use NDArray as nd; use function is_null; @@ -60,6 +61,8 @@ class Ridge implements Estimator, Learner, RanksFeatures, Persistable */ protected ?\Tensor\Vector $coefficients = null; + protected ?nd $coefficientsNd = null; + /** * @param float $l2Penalty * @throws InvalidArgumentException @@ -161,7 +164,7 @@ public function train(Dataset $dataset) : void $biases = Matrix::ones($dataset->numSamples(), 1); $x = Matrix::build($dataset->samples())->augmentLeft($biases); - $y = Vector::build($dataset->labels()); + $y = nd::array($dataset->labels()); /** @var int<0,max> $nHat */ $nHat = $x->n() - 1; @@ -170,15 +173,18 @@ public function train(Dataset $dataset) : void array_unshift($penalties, 0.0); - $penalties = Matrix::diagonal($penalties); + $penalties = nd::array(Matrix::diagonal($penalties)->asArray()); + + $xNp = nd::array($x->asArray()); + $xT = nd::transpose($xNp); - $xT = $x->transpose(); + $xMul = nd::matmul($xT, $xNp); + $xMulAdd = nd::add($xMul, $penalties); + $xMulAddInv = nd::inv($xMulAdd); + $xtDotY = nd::dot($xT, $y); - $coefficients = $xT->matmul($x) - ->add($penalties) - ->inverse() - ->dot($xT->dot($y)) - ->asArray(); + $this->coefficientsNd = nd::dot($xMulAddInv, $xtDotY); + $coefficients = $this->coefficientsNd->toArray(); $this->bias = (float) array_shift($coefficients); $this->coefficients = Vector::quick($coefficients); @@ -199,10 +205,10 @@ public function predict(Dataset $dataset) : array DatasetHasDimensionality::with($dataset, count($this->coefficients))->check(); - return Matrix::build($dataset->samples()) - ->dot($this->coefficients) - ->add($this->bias) - ->asArray(); + $datasetNd = nd::array($dataset->samples()); + $datasetDotCoefficients = nd::dot($datasetNd, $this->coefficientsNd); + + return nd::add($datasetDotCoefficients, $this->bias)->toArray(); } /**