Skip to content

Commit

Permalink
Fix test and appease Stan
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdalpino committed Nov 9, 2024
1 parent 715b0e7 commit 18acb0b
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
- 2.5.1
- Fix bug in SVM (SVC and SVR) inferencing

- 2.5.0
- Added Vantage Point Spatial tree
- Blob Generator can now `simulate()` a Dataset object
Expand Down
1 change: 0 additions & 1 deletion src/AnomalyDetectors/OneClassSVM.php
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ public function __construct(
new ExtensionIsLoaded('svm'),
new ExtensionMinimumVersion('svm', '0.2.0'),
])->check();


if ($nu < 0.0 or $nu > 1.0) {
throw new InvalidArgumentException('Nu must be between'
Expand Down
2 changes: 1 addition & 1 deletion src/Classifiers/AdaBoost.php
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ public function train(Dataset $dataset) : void
* Make predictions from a dataset.
*
* @param Dataset $dataset
* @return list<string>
* @return list<int|string>
*/
public function predict(Dataset $dataset) : array
{
Expand Down
2 changes: 1 addition & 1 deletion src/Datasets/Generators/Agglomerate.php
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public function __construct(array $generators = [], ?array $weights = null)
}

$this->generators = $generators;
$this->weights = array_combine(array_keys($generators), $weights) ?: [];
$this->weights = array_combine(array_keys($generators), $weights);

Check failure on line 107 in src/Datasets/Generators/Agglomerate.php

View workflow job for this annotation

GitHub Actions / PHP 7.4 on ubuntu-latest

Property Rubix\ML\Datasets\Generators\Agglomerate::$weights (array<float>) does not accept array<float|int>|false.
$this->dimensions = $dimensions;
}

Expand Down
10 changes: 5 additions & 5 deletions tests/AnomalyDetectors/OneClassSVMTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
use Rubix\ML\Estimator;
use Rubix\ML\EstimatorType;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Kernels\SVM\Polynomial;
use Rubix\ML\Kernels\SVM\RBF;
use Rubix\ML\Datasets\Generators\Blob;
use Rubix\ML\Datasets\Generators\Circle;
use Rubix\ML\AnomalyDetectors\OneClassSVM;
Expand Down Expand Up @@ -43,7 +43,7 @@ class OneClassSVMTest extends TestCase
*
* @var float
*/
protected const MIN_SCORE = 0.5;
protected const MIN_SCORE = 0.7;

/**
* Constant used to see the random number generator.
Expand Down Expand Up @@ -77,7 +77,7 @@ protected function setUp() : void
1 => new Circle(0.0, 0.0, 8.0, 1.0),
], [0.9, 0.1]);

$this->estimator = new OneClassSVM(0.01, new Polynomial(4, 1e-3), true, 1e-4);
$this->estimator = new OneClassSVM(0.3, new RBF(), true, 1e-4);

$this->metric = new FBeta();

Expand Down Expand Up @@ -125,8 +125,8 @@ public function compatibility() : void
public function params() : void
{
$expected = [
'nu' => 0.01,
'kernel' => new Polynomial(4, 1e-3),
'nu' => 0.3,
'kernel' => new RBF(),
'shrinking' => true,
'tolerance' => 0.0001,
'cache size' => 100.0,
Expand Down
2 changes: 1 addition & 1 deletion tests/NeuralNet/FeedForwardTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public function build() : void
*/
public function layers() : void
{
$this->assertCount(7, $this->network->layers());
$this->assertCount(5, iterator_to_array($this->network->layers()));
}

/**
Expand Down

0 comments on commit 18acb0b

Please sign in to comment.