-
-
Notifications
You must be signed in to change notification settings - Fork 183
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
use Rubix\ML\Datasets\Dataset; | ||
use Rubix\ML\Datasets\Unlabeled; | ||
use Rubix\ML\Exceptions\InvalidArgumentException; | ||
use NDArray as nd; | ||
|
||
use function count; | ||
use function sqrt; | ||
|
@@ -32,14 +33,12 @@ class Blob implements Generator | |
* | ||
* @var Vector | ||
*/ | ||
protected \Tensor\Vector $center; | ||
protected nd $center; | ||
Check failure on line 36 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on windows-latest
Check failure on line 36 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on windows-latest
Check failure on line 36 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.1 on ubuntu-latest
Check failure on line 36 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.1 on ubuntu-latest
Check failure on line 36 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on ubuntu-latest
Check failure on line 36 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on ubuntu-latest
Check failure on line 36 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.2 on ubuntu-latest
|
||
|
||
/** | ||
* The standard deviation of the blob. | ||
* | ||
* @var \Tensor\Vector|int|float | ||
*/ | ||
protected $stdDev; | ||
protected int|float|nd $stdDev; | ||
Check failure on line 41 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on windows-latest
Check failure on line 41 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.1 on ubuntu-latest
Check failure on line 41 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on ubuntu-latest
|
||
|
||
/** | ||
* Fit a Blob generator to the samples in a dataset. | ||
|
@@ -94,15 +93,15 @@ public function __construct(array $center = [0, 0], $stdDev = 1.0) | |
} | ||
} | ||
|
||
$stdDev = Vector::quick($stdDev); | ||
$stdDev = nd::array($stdDev); | ||
Check failure on line 96 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on windows-latest
Check failure on line 96 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.1 on ubuntu-latest
Check failure on line 96 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on ubuntu-latest
|
||
} else { | ||
if ($stdDev < 0) { | ||
throw new InvalidArgumentException('Standard deviation' | ||
. " must be greater than 0, $stdDev given."); | ||
} | ||
} | ||
|
||
$this->center = Vector::quick($center); | ||
$this->center = nd::array($center); | ||
Check failure on line 104 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on windows-latest
Check failure on line 104 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.1 on ubuntu-latest
Check failure on line 104 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on ubuntu-latest
|
||
$this->stdDev = $stdDev; | ||
} | ||
|
||
|
@@ -113,7 +112,7 @@ public function __construct(array $center = [0, 0], $stdDev = 1.0) | |
*/ | ||
public function center() : array | ||
{ | ||
return $this->center->asArray(); | ||
return $this->center->toArray(); | ||
Check failure on line 115 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on windows-latest
Check failure on line 115 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.1 on ubuntu-latest
Check failure on line 115 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on ubuntu-latest
|
||
} | ||
|
||
/** | ||
|
@@ -125,7 +124,7 @@ public function center() : array | |
*/ | ||
public function dimensions() : int | ||
{ | ||
return $this->center->n(); | ||
return $this->center->size(); | ||
Check failure on line 127 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on windows-latest
Check failure on line 127 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.1 on ubuntu-latest
Check failure on line 127 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on ubuntu-latest
|
||
} | ||
|
||
/** | ||
|
@@ -138,11 +137,10 @@ public function generate(int $n) : Unlabeled | |
{ | ||
$d = $this->dimensions(); | ||
|
||
$samples = Matrix::gaussian($n, $d) | ||
->multiply($this->stdDev) | ||
->add($this->center) | ||
->asArray(); | ||
$samples = nd::normal([$n, $d]); | ||
Check failure on line 140 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on windows-latest
Check failure on line 140 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.1 on ubuntu-latest
Check failure on line 140 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on ubuntu-latest
|
||
$samplesMul = nd::multiply($samples, $this->stdDev); | ||
Check failure on line 141 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on windows-latest
Check failure on line 141 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.1 on ubuntu-latest
Check failure on line 141 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on ubuntu-latest
|
||
$samplesMulAddCenter = nd::add($samplesMul, $this->center); | ||
Check failure on line 142 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on windows-latest
Check failure on line 142 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.1 on ubuntu-latest
Check failure on line 142 in src/Datasets/Generators/Blob.php GitHub Actions / PHP 8.0 on ubuntu-latest
|
||
|
||
return Unlabeled::quick($samples); | ||
return Unlabeled::quick($samplesMulAddCenter->toArray()); | ||
} | ||
} |