Skip to content

Commit

Permalink
Wrapper interface (#314)
Browse files Browse the repository at this point in the history
* Add Wrapper interface for models wrappers

* Add WrapperAware trait

* Fix PhpDoc

* Revert "Add WrapperAware trait"

This reverts commit 241abc4.

* Rename Wrapper interface to EstimatorWrapper

* PHP CS fix
  • Loading branch information
ElGigi authored Jan 26, 2024
1 parent 865e54a commit 4c16268
Show file tree
Hide file tree
Showing 84 changed files with 165 additions and 145 deletions.
2 changes: 1 addition & 1 deletion src/AnomalyDetectors/LocalOutlierFactor.php
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class LocalOutlierFactor implements Estimator, Learner, Scoring, Persistable
*
* @var Spatial
*/
protected \Rubix\ML\Graph\Trees\Spatial $tree;
protected Spatial $tree;

/**
* The precomputed k distances between each training sample and its k'th nearest neighbor.
Expand Down
2 changes: 1 addition & 1 deletion src/AnomalyDetectors/Loda.php
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class Loda implements Estimator, Learner, Online, Scoring, Persistable
*
* @var \Tensor\Matrix|null
*/
protected ?\Tensor\Matrix $r = null;
protected ?Matrix $r = null;

/**
* The edges and bin counts of each histogram.
Expand Down
4 changes: 2 additions & 2 deletions src/AnomalyDetectors/OneClassSVM.php
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class OneClassSVM implements Estimator, Learner
*
* @var svm
*/
protected \svm $svm;
protected svm $svm;

/**
* The hyper-parameters of the model.
Expand All @@ -58,7 +58,7 @@ class OneClassSVM implements Estimator, Learner
*
* @var \svmmodel|null
*/
protected ?\svmmodel $model = null;
protected ?svmmodel $model = null;

/**
* @param float $nu
Expand Down
2 changes: 1 addition & 1 deletion src/BootstrapAggregator.php
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class BootstrapAggregator implements Estimator, Learner, Parallel, Persistable
*
* @var Learner
*/
protected \Rubix\ML\Learner $base;
protected Learner $base;

/**
* The number of base learners to train in the ensemble.
Expand Down
2 changes: 1 addition & 1 deletion src/Classifiers/AdaBoost.php
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class AdaBoost implements Estimator, Learner, Probabilistic, Verbose, Persistabl
*
* @var Learner
*/
protected \Rubix\ML\Learner $base;
protected Learner $base;

/**
* The learning rate of the ensemble i.e. the *shrinkage* applied to each step.
Expand Down
2 changes: 1 addition & 1 deletion src/Classifiers/KDNeighbors.php
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class KDNeighbors implements Estimator, Learner, Probabilistic, Persistable
*
* @var Spatial
*/
protected \Rubix\ML\Graph\Trees\Spatial $tree;
protected Spatial $tree;

/**
* The zero vector for the possible class outcomes.
Expand Down
2 changes: 1 addition & 1 deletion src/Classifiers/KNearestNeighbors.php
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class KNearestNeighbors implements Estimator, Learner, Online, Probabilistic, Pe
*
* @var Distance
*/
protected \Rubix\ML\Kernels\Distance\Distance $kernel;
protected Distance $kernel;

/**
* The zero vector for the possible class outcomes.
Expand Down
6 changes: 3 additions & 3 deletions src/Classifiers/LogisticRegression.php
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class LogisticRegression implements Estimator, Learner, Online, Probabilistic, R
*
* @var Optimizer
*/
protected \Rubix\ML\NeuralNet\Optimizers\Optimizer $optimizer;
protected Optimizer $optimizer;

/**
* The amount of L2 regularization applied to the weights of the output layer.
Expand Down Expand Up @@ -103,14 +103,14 @@ class LogisticRegression implements Estimator, Learner, Online, Probabilistic, R
*
* @var ClassificationLoss
*/
protected \Rubix\ML\NeuralNet\CostFunctions\ClassificationLoss $costFn;
protected ClassificationLoss $costFn;

/**
* The underlying neural network instance.
*
* @var \Rubix\ML\NeuralNet\FeedForward|null
*/
protected ?\Rubix\ML\NeuralNet\FeedForward $network = null;
protected ?FeedForward $network = null;

/**
* The unique class labels.
Expand Down
4 changes: 2 additions & 2 deletions src/Classifiers/LogitBoost.php
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class LogitBoost implements Estimator, Learner, Probabilistic, RanksFeatures, Ve
*
* @var Learner
*/
protected \Rubix\ML\Learner $booster;
protected Learner $booster;

/**
* The learning rate of the ensemble i.e. the *shrinkage* applied to each step.
Expand Down Expand Up @@ -138,7 +138,7 @@ class LogitBoost implements Estimator, Learner, Probabilistic, RanksFeatures, Ve
*
* @var Metric
*/
protected \Rubix\ML\CrossValidation\Metrics\Metric $metric;
protected Metric $metric;

/**
* The ensemble of boosters.
Expand Down
8 changes: 4 additions & 4 deletions src/Classifiers/MultilayerPerceptron.php
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class MultilayerPerceptron implements Estimator, Learner, Online, Probabilistic,
*
* @var Optimizer
*/
protected \Rubix\ML\NeuralNet\Optimizers\Optimizer $optimizer;
protected Optimizer $optimizer;

/**
* The amount of L2 regularization applied to the weights of the output layer.
Expand Down Expand Up @@ -127,21 +127,21 @@ class MultilayerPerceptron implements Estimator, Learner, Online, Probabilistic,
*
* @var ClassificationLoss
*/
protected \Rubix\ML\NeuralNet\CostFunctions\ClassificationLoss $costFn;
protected ClassificationLoss $costFn;

/**
* The validation metric used to score the generalization performance of the model during training.
*
* @var Metric
*/
protected \Rubix\ML\CrossValidation\Metrics\Metric $metric;
protected Metric $metric;

/**
* The underlying neural network instance.
*
* @var \Rubix\ML\NeuralNet\FeedForward|null
*/
protected ?\Rubix\ML\NeuralNet\FeedForward $network = null;
protected ?FeedForward $network = null;

/**
* The unique class labels.
Expand Down
2 changes: 1 addition & 1 deletion src/Classifiers/OneVsRest.php
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class OneVsRest implements Estimator, Learner, Probabilistic, Parallel, Persista
*
* @var Learner
*/
protected \Rubix\ML\Learner $base;
protected Learner $base;

/**
* A map of each class to its binary classifier.
Expand Down
2 changes: 1 addition & 1 deletion src/Classifiers/RadiusNeighbors.php
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class RadiusNeighbors implements Estimator, Learner, Probabilistic, Persistable
*
* @var Spatial
*/
protected \Rubix\ML\Graph\Trees\Spatial $tree;
protected Spatial $tree;

/**
* The class label for any samples that have 0 neighbors within the specified radius.
Expand Down
2 changes: 1 addition & 1 deletion src/Classifiers/RandomForest.php
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class RandomForest implements Estimator, Learner, Probabilistic, Parallel, Ranks
*
* @var Learner
*/
protected \Rubix\ML\Learner $base;
protected Learner $base;

/**
* The number of learners to train in the ensemble.
Expand Down
6 changes: 3 additions & 3 deletions src/Classifiers/SoftmaxClassifier.php
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class SoftmaxClassifier implements Estimator, Learner, Online, Probabilistic, Ve
*
* @var Optimizer
*/
protected \Rubix\ML\NeuralNet\Optimizers\Optimizer $optimizer;
protected Optimizer $optimizer;

/**
* The amount of L2 regularization applied to the weights of the output layer.
Expand Down Expand Up @@ -99,14 +99,14 @@ class SoftmaxClassifier implements Estimator, Learner, Online, Probabilistic, Ve
*
* @var ClassificationLoss
*/
protected \Rubix\ML\NeuralNet\CostFunctions\ClassificationLoss $costFn;
protected ClassificationLoss $costFn;

/**
* The underlying neural network instance.
*
* @var \Rubix\ML\NeuralNet\FeedForward|null
*/
protected ?\Rubix\ML\NeuralNet\FeedForward $network = null;
protected ?FeedForward $network = null;

/**
* The unique class labels.
Expand Down
2 changes: 1 addition & 1 deletion src/Clusterers/DBSCAN.php
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class DBSCAN implements Estimator
*
* @var Spatial
*/
protected \Rubix\ML\Graph\Trees\Spatial $tree;
protected Spatial $tree;

/**
* @param float $radius
Expand Down
4 changes: 2 additions & 2 deletions src/Clusterers/FuzzyCMeans.php
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ class FuzzyCMeans implements Estimator, Learner, Probabilistic, Verbose, Persist
*
* @var Distance
*/
protected \Rubix\ML\Kernels\Distance\Distance $kernel;
protected Distance $kernel;

/**
* The cluster centroid seeder.
*
* @var Seeder
*/
protected \Rubix\ML\Clusterers\Seeders\Seeder $seeder;
protected Seeder $seeder;

/**
* The computed centroid vectors of the training data.
Expand Down
2 changes: 1 addition & 1 deletion src/Clusterers/GaussianMixture.php
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class GaussianMixture implements Estimator, Learner, Probabilistic, Verbose, Per
*
* @var Seeder
*/
protected \Rubix\ML\Clusterers\Seeders\Seeder $seeder;
protected Seeder $seeder;

/**
* The precomputed log prior probabilities of each cluster.
Expand Down
4 changes: 2 additions & 2 deletions src/Clusterers/KMeans.php
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,14 @@ class KMeans implements Estimator, Learner, Online, Probabilistic, Verbose, Pers
*
* @var Distance
*/
protected \Rubix\ML\Kernels\Distance\Distance $kernel;
protected Distance $kernel;

/**
* The cluster centroid seeder.
*
* @var Seeder
*/
protected \Rubix\ML\Clusterers\Seeders\Seeder $seeder;
protected Seeder $seeder;

/**
* The computed centroid vectors of the training data.
Expand Down
4 changes: 2 additions & 2 deletions src/Clusterers/MeanShift.php
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@ class MeanShift implements Estimator, Learner, Probabilistic, Verbose, Persistab
*
* @var Spatial
*/
protected \Rubix\ML\Graph\Trees\Spatial $tree;
protected Spatial $tree;

/**
* The cluster centroid seeder.
*
* @var Seeder
*/
protected \Rubix\ML\Clusterers\Seeders\Seeder $seeder;
protected Seeder $seeder;

/**
* The computed centroid vectors of the training data.
Expand Down
2 changes: 1 addition & 1 deletion src/Clusterers/Seeders/KMC2.php
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class KMC2 implements Seeder
*
* @var Distance
*/
protected \Rubix\ML\Kernels\Distance\Distance $kernel;
protected Distance $kernel;

/**
* @param int $m
Expand Down
2 changes: 1 addition & 1 deletion src/Clusterers/Seeders/PlusPlus.php
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class PlusPlus implements Seeder
*
* @var Distance
*/
protected \Rubix\ML\Kernels\Distance\Distance $kernel;
protected Distance $kernel;

/**
* @param \Rubix\ML\Kernels\Distance\Distance|null $kernel
Expand Down
2 changes: 1 addition & 1 deletion src/Datasets/Generators/Blob.php
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class Blob implements Generator
*
* @var Vector
*/
protected \Tensor\Vector $center;
protected Vector $center;

/**
* The standard deviation of the blob.
Expand Down
2 changes: 1 addition & 1 deletion src/Datasets/Generators/Circle.php
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Circle implements Generator
*
* @var Vector
*/
protected \Tensor\Vector $center;
protected Vector $center;

/**
* The scaling factor of the circle.
Expand Down
2 changes: 1 addition & 1 deletion src/Datasets/Generators/HalfMoon.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class HalfMoon implements Generator
*
* @var Vector
*/
protected \Tensor\Vector $center;
protected Vector $center;

/**
* The scaling factor of the half moon.
Expand Down
2 changes: 1 addition & 1 deletion src/Datasets/Generators/Hyperplane.php
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Hyperplane implements Generator
*
* @var Vector
*/
protected \Tensor\Vector $coefficients;
protected Vector $coefficients;

/**
* The y intercept term.
Expand Down
2 changes: 1 addition & 1 deletion src/Datasets/Generators/SwissRoll.php
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class SwissRoll implements Generator
*
* @var Vector
*/
protected \Tensor\Vector $center;
protected Vector $center;

/**
* The scaling factor of the swiss roll.
Expand Down
20 changes: 20 additions & 0 deletions src/EstimatorWrapper.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
<?php

namespace Rubix\ML;

/**
* Wrapper
*
* @category Machine Learning
* @package Rubix/ML
* @author Ronan Giron
*/
interface EstimatorWrapper extends Estimator
{
/**
* Return the base estimator instance.
*
* @return Estimator
*/
public function base() : Estimator;
}
2 changes: 1 addition & 1 deletion src/Extractors/SQLTable.php
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class SQLTable implements Extractor
*
* @var PDO
*/
protected \PDO $connection;
protected PDO $connection;

/**
* The name of the table to select from.
Expand Down
2 changes: 1 addition & 1 deletion src/Graph/Nodes/Clique.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class Clique implements Hypersphere, BinaryNode
*
* @var Labeled
*/
protected \Rubix\ML\Datasets\Labeled $dataset;
protected Labeled $dataset;

/**
* The centroid or multivariate mean of the cluster.
Expand Down
2 changes: 1 addition & 1 deletion src/Graph/Nodes/Neighborhood.php
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Neighborhood implements Hypercube, BinaryNode
*
* @var Labeled
*/
protected \Rubix\ML\Datasets\Labeled $dataset;
protected Labeled $dataset;

/**
* The multivariate minimum of the bounding box.
Expand Down
4 changes: 2 additions & 2 deletions src/Graph/Nodes/Traits/HasBinaryChildrenTrait.php
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ trait HasBinaryChildrenTrait
*
* @var \Rubix\ML\Graph\Nodes\BinaryNode|null
*/
protected ?\Rubix\ML\Graph\Nodes\BinaryNode $left = null;
protected ?BinaryNode $left = null;

/**
* The right child node.
*
* @var \Rubix\ML\Graph\Nodes\BinaryNode|null
*/
protected ?\Rubix\ML\Graph\Nodes\BinaryNode $right = null;
protected ?BinaryNode $right = null;

/**
* Return the children of this node in a generator.
Expand Down
Loading

0 comments on commit 4c16268

Please sign in to comment.