Skip to content

Commit

Permalink
chore: use NumPower
Browse files Browse the repository at this point in the history
  • Loading branch information
mcharytoniuk committed Feb 7, 2024
1 parent 57e1811 commit 350b964
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions src/Regressors/Ridge.php
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator;
use Rubix\ML\Exceptions\InvalidArgumentException;
use Rubix\ML\Exceptions\RuntimeException;
use NDArray as nd;

use function is_null;

Expand Down Expand Up @@ -60,6 +61,8 @@ class Ridge implements Estimator, Learner, RanksFeatures, Persistable
*/
protected ?\Tensor\Vector $coefficients = null;

protected ?nd $coefficientsNd = null;

Check failure on line 64 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Property Rubix\ML\Regressors\Ridge::$coefficientsNd has unknown class NDArray as its type.

Check failure on line 64 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Property Rubix\ML\Regressors\Ridge::$coefficientsNd has unknown class NDArray as its type.

Check failure on line 64 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Property Rubix\ML\Regressors\Ridge::$coefficientsNd has unknown class NDArray as its type.

/**
* @param float $l2Penalty
* @throws InvalidArgumentException
Expand Down Expand Up @@ -161,7 +164,7 @@ public function train(Dataset $dataset) : void
$biases = Matrix::ones($dataset->numSamples(), 1);

$x = Matrix::build($dataset->samples())->augmentLeft($biases);
$y = Vector::build($dataset->labels());
$y = nd::array($dataset->labels());

Check failure on line 167 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 167 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 167 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method array() on an unknown class NDArray.

/** @var int<0,max> $nHat */
$nHat = $x->n() - 1;
Expand All @@ -170,15 +173,18 @@ public function train(Dataset $dataset) : void

array_unshift($penalties, 0.0);

$penalties = Matrix::diagonal($penalties);
$penalties = nd::array(Matrix::diagonal($penalties)->asArray());

Check failure on line 176 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 176 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 176 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method array() on an unknown class NDArray.

$xNp = nd::array($x->asArray());

Check failure on line 178 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 178 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 178 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method array() on an unknown class NDArray.
$xT = nd::transpose($xNp);

Check failure on line 179 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method transpose() on an unknown class NDArray.

Check failure on line 179 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method transpose() on an unknown class NDArray.

Check failure on line 179 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method transpose() on an unknown class NDArray.

$xT = $x->transpose();
$xMul = nd::matmul($xT, $xNp);

Check failure on line 181 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method matmul() on an unknown class NDArray.

Check failure on line 181 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method matmul() on an unknown class NDArray.

Check failure on line 181 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method matmul() on an unknown class NDArray.
$xMulAdd = nd::add($xMul, $penalties);

Check failure on line 182 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method add() on an unknown class NDArray.

Check failure on line 182 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method add() on an unknown class NDArray.

Check failure on line 182 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method add() on an unknown class NDArray.
$xMulAddInv = nd::inv($xMulAdd);

Check failure on line 183 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method inv() on an unknown class NDArray.

Check failure on line 183 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method inv() on an unknown class NDArray.

Check failure on line 183 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method inv() on an unknown class NDArray.
$xtDotY = nd::dot($xT, $y);

Check failure on line 184 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method dot() on an unknown class NDArray.

Check failure on line 184 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method dot() on an unknown class NDArray.

Check failure on line 184 in src/Regressors/Ridge.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method dot() on an unknown class NDArray.

$coefficients = $xT->matmul($x)
->add($penalties)
->inverse()
->dot($xT->dot($y))
->asArray();
$this->coefficientsNd = nd::dot($xMulAddInv, $xtDotY);
$coefficients = $this->coefficientsNd->toArray();

$this->bias = (float) array_shift($coefficients);
$this->coefficients = Vector::quick($coefficients);
Expand All @@ -199,10 +205,10 @@ public function predict(Dataset $dataset) : array

DatasetHasDimensionality::with($dataset, count($this->coefficients))->check();

return Matrix::build($dataset->samples())
->dot($this->coefficients)
->add($this->bias)
->asArray();
$datasetNd = nd::array($dataset->samples());
$datasetDotCoefficients = nd::dot($datasetNd, $this->coefficientsNd);

return nd::add($datasetDotCoefficients, $this->bias)->toArray();
}

/**
Expand Down

0 comments on commit 350b964

Please sign in to comment.