From 4c16268996eb78073cf6ac628f6824655d744754 Mon Sep 17 00:00:00 2001 From: Ronan Giron Date: Fri, 26 Jan 2024 03:28:57 +0100 Subject: [PATCH] Wrapper interface (#314) * Add Wrapper interface for models wrappers * Add WrapperAware trait * Fix PhpDoc * Revert "Add WrapperAware trait" This reverts commit 241abc4317eec701211b7a88a17a1b610c366dfe. * Rename Wrapper interface to EstimatorWrapper * PHP CS fix --- src/AnomalyDetectors/LocalOutlierFactor.php | 2 +- src/AnomalyDetectors/Loda.php | 2 +- src/AnomalyDetectors/OneClassSVM.php | 4 ++-- src/BootstrapAggregator.php | 2 +- src/Classifiers/AdaBoost.php | 2 +- src/Classifiers/KDNeighbors.php | 2 +- src/Classifiers/KNearestNeighbors.php | 2 +- src/Classifiers/LogisticRegression.php | 6 +++--- src/Classifiers/LogitBoost.php | 4 ++-- src/Classifiers/MultilayerPerceptron.php | 8 ++++---- src/Classifiers/OneVsRest.php | 2 +- src/Classifiers/RadiusNeighbors.php | 2 +- src/Classifiers/RandomForest.php | 2 +- src/Classifiers/SoftmaxClassifier.php | 6 +++--- src/Clusterers/DBSCAN.php | 2 +- src/Clusterers/FuzzyCMeans.php | 4 ++-- src/Clusterers/GaussianMixture.php | 2 +- src/Clusterers/KMeans.php | 4 ++-- src/Clusterers/MeanShift.php | 4 ++-- src/Clusterers/Seeders/KMC2.php | 2 +- src/Clusterers/Seeders/PlusPlus.php | 2 +- src/Datasets/Generators/Blob.php | 2 +- src/Datasets/Generators/Circle.php | 2 +- src/Datasets/Generators/HalfMoon.php | 2 +- src/Datasets/Generators/Hyperplane.php | 2 +- src/Datasets/Generators/SwissRoll.php | 2 +- src/EstimatorWrapper.php | 20 +++++++++++++++++++ src/Extractors/SQLTable.php | 2 +- src/Graph/Nodes/Clique.php | 2 +- src/Graph/Nodes/Neighborhood.php | 2 +- .../Nodes/Traits/HasBinaryChildrenTrait.php | 4 ++-- src/Graph/Trees/BallTree.php | 4 ++-- src/Graph/Trees/DecisionTree.php | 2 +- src/Graph/Trees/ITree.php | 2 +- src/Graph/Trees/KDTree.php | 4 ++-- src/GridSearch.php | 8 ++++---- src/NeuralNet/FeedForward.php | 6 +++--- src/NeuralNet/Layers/Activation.php | 6 +++--- src/NeuralNet/Layers/BatchNorm.php | 10 +++++----- src/NeuralNet/Layers/Binary.php | 8 ++++---- src/NeuralNet/Layers/Continuous.php | 4 ++-- src/NeuralNet/Layers/Dense.php | 10 +++++----- src/NeuralNet/Layers/Dropout.php | 2 +- src/NeuralNet/Layers/Multiclass.php | 8 ++++---- src/NeuralNet/Layers/PReLU.php | 6 +++--- src/NeuralNet/Layers/Swish.php | 10 +++++----- src/NeuralNet/Parameter.php | 2 +- src/PersistentModel.php | 8 ++++---- src/Pipeline.php | 4 ++-- src/Regressors/Adaline.php | 6 +++--- src/Regressors/GradientBoost.php | 4 ++-- src/Regressors/KDNeighborsRegressor.php | 2 +- src/Regressors/KNNRegressor.php | 2 +- src/Regressors/MLPRegressor.php | 8 ++++---- src/Regressors/RadiusNeighborsRegressor.php | 2 +- src/Regressors/Ridge.php | 2 +- src/Regressors/SVR.php | 4 ++-- src/Serializers/GzipNative.php | 2 +- src/Serializers/RBX.php | 2 +- .../DatasetHasDimensionality.php | 2 +- src/Specifications/DatasetIsLabeled.php | 2 +- src/Specifications/DatasetIsNotEmpty.php | 2 +- .../EstimatorIsCompatibleWithMetric.php | 4 ++-- .../LabelsAreCompatibleWithLearner.php | 4 ++-- .../SamplesAreCompatibleWithDistance.php | 4 ++-- .../SamplesAreCompatibleWithEstimator.php | 4 ++-- .../SamplesAreCompatibleWithTransformer.php | 4 ++-- src/Tokenizers/KSkipNGram.php | 4 ++-- src/Tokenizers/NGram.php | 4 ++-- src/Traits/LoggerAware.php | 2 +- src/Traits/Multiprocessing.php | 2 +- src/Transformers/GaussianRandomProjector.php | 2 +- src/Transformers/HotDeckImputer.php | 2 +- src/Transformers/KNNImputer.php | 2 +- .../LinearDiscriminantAnalysis.php | 2 +- src/Transformers/MissingDataImputer.php | 4 ++-- .../PrincipalComponentAnalysis.php | 2 +- src/Transformers/TSNE.php | 2 +- src/Transformers/TokenHashingVectorizer.php | 2 +- src/Transformers/TruncatedSVD.php | 2 +- src/Transformers/WordCountVectorizer.php | 2 +- tests/Graph/Nodes/NeighborhoodTest.php | 2 +- tests/NeuralNet/ParameterTest.php | 2 +- tests/Transformers/ImageRotatorTest.php | 2 +- 84 files changed, 165 insertions(+), 145 deletions(-) create mode 100644 src/EstimatorWrapper.php diff --git a/src/AnomalyDetectors/LocalOutlierFactor.php b/src/AnomalyDetectors/LocalOutlierFactor.php index 4ebda0f00..5c798b4b6 100644 --- a/src/AnomalyDetectors/LocalOutlierFactor.php +++ b/src/AnomalyDetectors/LocalOutlierFactor.php @@ -67,7 +67,7 @@ class LocalOutlierFactor implements Estimator, Learner, Scoring, Persistable * * @var Spatial */ - protected \Rubix\ML\Graph\Trees\Spatial $tree; + protected Spatial $tree; /** * The precomputed k distances between each training sample and its k'th nearest neighbor. diff --git a/src/AnomalyDetectors/Loda.php b/src/AnomalyDetectors/Loda.php index f06b9a4a7..cf76762e3 100644 --- a/src/AnomalyDetectors/Loda.php +++ b/src/AnomalyDetectors/Loda.php @@ -100,7 +100,7 @@ class Loda implements Estimator, Learner, Online, Scoring, Persistable * * @var \Tensor\Matrix|null */ - protected ?\Tensor\Matrix $r = null; + protected ?Matrix $r = null; /** * The edges and bin counts of each histogram. diff --git a/src/AnomalyDetectors/OneClassSVM.php b/src/AnomalyDetectors/OneClassSVM.php index faf1111be..10969ab6f 100644 --- a/src/AnomalyDetectors/OneClassSVM.php +++ b/src/AnomalyDetectors/OneClassSVM.php @@ -44,7 +44,7 @@ class OneClassSVM implements Estimator, Learner * * @var svm */ - protected \svm $svm; + protected svm $svm; /** * The hyper-parameters of the model. @@ -58,7 +58,7 @@ class OneClassSVM implements Estimator, Learner * * @var \svmmodel|null */ - protected ?\svmmodel $model = null; + protected ?svmmodel $model = null; /** * @param float $nu diff --git a/src/BootstrapAggregator.php b/src/BootstrapAggregator.php index 742dcc6ab..fc30cc882 100644 --- a/src/BootstrapAggregator.php +++ b/src/BootstrapAggregator.php @@ -64,7 +64,7 @@ class BootstrapAggregator implements Estimator, Learner, Parallel, Persistable * * @var Learner */ - protected \Rubix\ML\Learner $base; + protected Learner $base; /** * The number of base learners to train in the ensemble. diff --git a/src/Classifiers/AdaBoost.php b/src/Classifiers/AdaBoost.php index 4f428c5f4..9c74fbd7a 100644 --- a/src/Classifiers/AdaBoost.php +++ b/src/Classifiers/AdaBoost.php @@ -72,7 +72,7 @@ class AdaBoost implements Estimator, Learner, Probabilistic, Verbose, Persistabl * * @var Learner */ - protected \Rubix\ML\Learner $base; + protected Learner $base; /** * The learning rate of the ensemble i.e. the *shrinkage* applied to each step. diff --git a/src/Classifiers/KDNeighbors.php b/src/Classifiers/KDNeighbors.php index 642b64e16..4ef9f5864 100644 --- a/src/Classifiers/KDNeighbors.php +++ b/src/Classifiers/KDNeighbors.php @@ -60,7 +60,7 @@ class KDNeighbors implements Estimator, Learner, Probabilistic, Persistable * * @var Spatial */ - protected \Rubix\ML\Graph\Trees\Spatial $tree; + protected Spatial $tree; /** * The zero vector for the possible class outcomes. diff --git a/src/Classifiers/KNearestNeighbors.php b/src/Classifiers/KNearestNeighbors.php index d5293c932..ee5c4b39b 100644 --- a/src/Classifiers/KNearestNeighbors.php +++ b/src/Classifiers/KNearestNeighbors.php @@ -62,7 +62,7 @@ class KNearestNeighbors implements Estimator, Learner, Online, Probabilistic, Pe * * @var Distance */ - protected \Rubix\ML\Kernels\Distance\Distance $kernel; + protected Distance $kernel; /** * The zero vector for the possible class outcomes. diff --git a/src/Classifiers/LogisticRegression.php b/src/Classifiers/LogisticRegression.php index ba67b5f57..b48ea7239 100644 --- a/src/Classifiers/LogisticRegression.php +++ b/src/Classifiers/LogisticRegression.php @@ -67,7 +67,7 @@ class LogisticRegression implements Estimator, Learner, Online, Probabilistic, R * * @var Optimizer */ - protected \Rubix\ML\NeuralNet\Optimizers\Optimizer $optimizer; + protected Optimizer $optimizer; /** * The amount of L2 regularization applied to the weights of the output layer. @@ -103,14 +103,14 @@ class LogisticRegression implements Estimator, Learner, Online, Probabilistic, R * * @var ClassificationLoss */ - protected \Rubix\ML\NeuralNet\CostFunctions\ClassificationLoss $costFn; + protected ClassificationLoss $costFn; /** * The underlying neural network instance. * * @var \Rubix\ML\NeuralNet\FeedForward|null */ - protected ?\Rubix\ML\NeuralNet\FeedForward $network = null; + protected ?FeedForward $network = null; /** * The unique class labels. diff --git a/src/Classifiers/LogitBoost.php b/src/Classifiers/LogitBoost.php index 071dfed39..f12588cfb 100644 --- a/src/Classifiers/LogitBoost.php +++ b/src/Classifiers/LogitBoost.php @@ -89,7 +89,7 @@ class LogitBoost implements Estimator, Learner, Probabilistic, RanksFeatures, Ve * * @var Learner */ - protected \Rubix\ML\Learner $booster; + protected Learner $booster; /** * The learning rate of the ensemble i.e. the *shrinkage* applied to each step. @@ -138,7 +138,7 @@ class LogitBoost implements Estimator, Learner, Probabilistic, RanksFeatures, Ve * * @var Metric */ - protected \Rubix\ML\CrossValidation\Metrics\Metric $metric; + protected Metric $metric; /** * The ensemble of boosters. diff --git a/src/Classifiers/MultilayerPerceptron.php b/src/Classifiers/MultilayerPerceptron.php index 1018d10c4..233c8b1eb 100644 --- a/src/Classifiers/MultilayerPerceptron.php +++ b/src/Classifiers/MultilayerPerceptron.php @@ -85,7 +85,7 @@ class MultilayerPerceptron implements Estimator, Learner, Online, Probabilistic, * * @var Optimizer */ - protected \Rubix\ML\NeuralNet\Optimizers\Optimizer $optimizer; + protected Optimizer $optimizer; /** * The amount of L2 regularization applied to the weights of the output layer. @@ -127,21 +127,21 @@ class MultilayerPerceptron implements Estimator, Learner, Online, Probabilistic, * * @var ClassificationLoss */ - protected \Rubix\ML\NeuralNet\CostFunctions\ClassificationLoss $costFn; + protected ClassificationLoss $costFn; /** * The validation metric used to score the generalization performance of the model during training. * * @var Metric */ - protected \Rubix\ML\CrossValidation\Metrics\Metric $metric; + protected Metric $metric; /** * The underlying neural network instance. * * @var \Rubix\ML\NeuralNet\FeedForward|null */ - protected ?\Rubix\ML\NeuralNet\FeedForward $network = null; + protected ?FeedForward $network = null; /** * The unique class labels. diff --git a/src/Classifiers/OneVsRest.php b/src/Classifiers/OneVsRest.php index 7c07e2627..841fb2751 100644 --- a/src/Classifiers/OneVsRest.php +++ b/src/Classifiers/OneVsRest.php @@ -51,7 +51,7 @@ class OneVsRest implements Estimator, Learner, Probabilistic, Parallel, Persista * * @var Learner */ - protected \Rubix\ML\Learner $base; + protected Learner $base; /** * A map of each class to its binary classifier. diff --git a/src/Classifiers/RadiusNeighbors.php b/src/Classifiers/RadiusNeighbors.php index b1e3544c9..1dc670186 100644 --- a/src/Classifiers/RadiusNeighbors.php +++ b/src/Classifiers/RadiusNeighbors.php @@ -60,7 +60,7 @@ class RadiusNeighbors implements Estimator, Learner, Probabilistic, Persistable * * @var Spatial */ - protected \Rubix\ML\Graph\Trees\Spatial $tree; + protected Spatial $tree; /** * The class label for any samples that have 0 neighbors within the specified radius. diff --git a/src/Classifiers/RandomForest.php b/src/Classifiers/RandomForest.php index 5d2b8d5cb..eb62f5e32 100644 --- a/src/Classifiers/RandomForest.php +++ b/src/Classifiers/RandomForest.php @@ -73,7 +73,7 @@ class RandomForest implements Estimator, Learner, Probabilistic, Parallel, Ranks * * @var Learner */ - protected \Rubix\ML\Learner $base; + protected Learner $base; /** * The number of learners to train in the ensemble. diff --git a/src/Classifiers/SoftmaxClassifier.php b/src/Classifiers/SoftmaxClassifier.php index 998035701..3038c04a3 100644 --- a/src/Classifiers/SoftmaxClassifier.php +++ b/src/Classifiers/SoftmaxClassifier.php @@ -64,7 +64,7 @@ class SoftmaxClassifier implements Estimator, Learner, Online, Probabilistic, Ve * * @var Optimizer */ - protected \Rubix\ML\NeuralNet\Optimizers\Optimizer $optimizer; + protected Optimizer $optimizer; /** * The amount of L2 regularization applied to the weights of the output layer. @@ -99,14 +99,14 @@ class SoftmaxClassifier implements Estimator, Learner, Online, Probabilistic, Ve * * @var ClassificationLoss */ - protected \Rubix\ML\NeuralNet\CostFunctions\ClassificationLoss $costFn; + protected ClassificationLoss $costFn; /** * The underlying neural network instance. * * @var \Rubix\ML\NeuralNet\FeedForward|null */ - protected ?\Rubix\ML\NeuralNet\FeedForward $network = null; + protected ?FeedForward $network = null; /** * The unique class labels. diff --git a/src/Clusterers/DBSCAN.php b/src/Clusterers/DBSCAN.php index 44112e02a..c24546d37 100644 --- a/src/Clusterers/DBSCAN.php +++ b/src/Clusterers/DBSCAN.php @@ -73,7 +73,7 @@ class DBSCAN implements Estimator * * @var Spatial */ - protected \Rubix\ML\Graph\Trees\Spatial $tree; + protected Spatial $tree; /** * @param float $radius diff --git a/src/Clusterers/FuzzyCMeans.php b/src/Clusterers/FuzzyCMeans.php index 49d5716ff..dd3b27b84 100644 --- a/src/Clusterers/FuzzyCMeans.php +++ b/src/Clusterers/FuzzyCMeans.php @@ -92,14 +92,14 @@ class FuzzyCMeans implements Estimator, Learner, Probabilistic, Verbose, Persist * * @var Distance */ - protected \Rubix\ML\Kernels\Distance\Distance $kernel; + protected Distance $kernel; /** * The cluster centroid seeder. * * @var Seeder */ - protected \Rubix\ML\Clusterers\Seeders\Seeder $seeder; + protected Seeder $seeder; /** * The computed centroid vectors of the training data. diff --git a/src/Clusterers/GaussianMixture.php b/src/Clusterers/GaussianMixture.php index 4445d0422..68338cfd9 100644 --- a/src/Clusterers/GaussianMixture.php +++ b/src/Clusterers/GaussianMixture.php @@ -97,7 +97,7 @@ class GaussianMixture implements Estimator, Learner, Probabilistic, Verbose, Per * * @var Seeder */ - protected \Rubix\ML\Clusterers\Seeders\Seeder $seeder; + protected Seeder $seeder; /** * The precomputed log prior probabilities of each cluster. diff --git a/src/Clusterers/KMeans.php b/src/Clusterers/KMeans.php index 852b178c4..280e70922 100644 --- a/src/Clusterers/KMeans.php +++ b/src/Clusterers/KMeans.php @@ -96,14 +96,14 @@ class KMeans implements Estimator, Learner, Online, Probabilistic, Verbose, Pers * * @var Distance */ - protected \Rubix\ML\Kernels\Distance\Distance $kernel; + protected Distance $kernel; /** * The cluster centroid seeder. * * @var Seeder */ - protected \Rubix\ML\Clusterers\Seeders\Seeder $seeder; + protected Seeder $seeder; /** * The computed centroid vectors of the training data. diff --git a/src/Clusterers/MeanShift.php b/src/Clusterers/MeanShift.php index 0d89ce00f..97af51353 100644 --- a/src/Clusterers/MeanShift.php +++ b/src/Clusterers/MeanShift.php @@ -104,14 +104,14 @@ class MeanShift implements Estimator, Learner, Probabilistic, Verbose, Persistab * * @var Spatial */ - protected \Rubix\ML\Graph\Trees\Spatial $tree; + protected Spatial $tree; /** * The cluster centroid seeder. * * @var Seeder */ - protected \Rubix\ML\Clusterers\Seeders\Seeder $seeder; + protected Seeder $seeder; /** * The computed centroid vectors of the training data. diff --git a/src/Clusterers/Seeders/KMC2.php b/src/Clusterers/Seeders/KMC2.php index d4e155e5e..717a29426 100644 --- a/src/Clusterers/Seeders/KMC2.php +++ b/src/Clusterers/Seeders/KMC2.php @@ -39,7 +39,7 @@ class KMC2 implements Seeder * * @var Distance */ - protected \Rubix\ML\Kernels\Distance\Distance $kernel; + protected Distance $kernel; /** * @param int $m diff --git a/src/Clusterers/Seeders/PlusPlus.php b/src/Clusterers/Seeders/PlusPlus.php index ad4f82e24..4a59d98b4 100644 --- a/src/Clusterers/Seeders/PlusPlus.php +++ b/src/Clusterers/Seeders/PlusPlus.php @@ -32,7 +32,7 @@ class PlusPlus implements Seeder * * @var Distance */ - protected \Rubix\ML\Kernels\Distance\Distance $kernel; + protected Distance $kernel; /** * @param \Rubix\ML\Kernels\Distance\Distance|null $kernel diff --git a/src/Datasets/Generators/Blob.php b/src/Datasets/Generators/Blob.php index 994d6fa52..62f703ae6 100644 --- a/src/Datasets/Generators/Blob.php +++ b/src/Datasets/Generators/Blob.php @@ -32,7 +32,7 @@ class Blob implements Generator * * @var Vector */ - protected \Tensor\Vector $center; + protected Vector $center; /** * The standard deviation of the blob. diff --git a/src/Datasets/Generators/Circle.php b/src/Datasets/Generators/Circle.php index d0a5ee14c..aed785d65 100644 --- a/src/Datasets/Generators/Circle.php +++ b/src/Datasets/Generators/Circle.php @@ -27,7 +27,7 @@ class Circle implements Generator * * @var Vector */ - protected \Tensor\Vector $center; + protected Vector $center; /** * The scaling factor of the circle. diff --git a/src/Datasets/Generators/HalfMoon.php b/src/Datasets/Generators/HalfMoon.php index 26486240e..e41a4a265 100644 --- a/src/Datasets/Generators/HalfMoon.php +++ b/src/Datasets/Generators/HalfMoon.php @@ -26,7 +26,7 @@ class HalfMoon implements Generator * * @var Vector */ - protected \Tensor\Vector $center; + protected Vector $center; /** * The scaling factor of the half moon. diff --git a/src/Datasets/Generators/Hyperplane.php b/src/Datasets/Generators/Hyperplane.php index 8afa59934..a5ae532bc 100644 --- a/src/Datasets/Generators/Hyperplane.php +++ b/src/Datasets/Generators/Hyperplane.php @@ -27,7 +27,7 @@ class Hyperplane implements Generator * * @var Vector */ - protected \Tensor\Vector $coefficients; + protected Vector $coefficients; /** * The y intercept term. diff --git a/src/Datasets/Generators/SwissRoll.php b/src/Datasets/Generators/SwissRoll.php index 8cd017ffa..f0899a284 100644 --- a/src/Datasets/Generators/SwissRoll.php +++ b/src/Datasets/Generators/SwissRoll.php @@ -33,7 +33,7 @@ class SwissRoll implements Generator * * @var Vector */ - protected \Tensor\Vector $center; + protected Vector $center; /** * The scaling factor of the swiss roll. diff --git a/src/EstimatorWrapper.php b/src/EstimatorWrapper.php new file mode 100644 index 000000000..aafb3ac8e --- /dev/null +++ b/src/EstimatorWrapper.php @@ -0,0 +1,20 @@ +assertInstanceOf(NeighborHood::class, $node); + $this->assertInstanceOf(Neighborhood::class, $node); $this->assertInstanceOf(Labeled::class, $node->dataset()); $this->assertEquals(self::BOX, iterator_to_array($node->sides())); } diff --git a/tests/NeuralNet/ParameterTest.php b/tests/NeuralNet/ParameterTest.php index 80fc03603..ec54adc51 100644 --- a/tests/NeuralNet/ParameterTest.php +++ b/tests/NeuralNet/ParameterTest.php @@ -16,7 +16,7 @@ class ParameterTest extends TestCase /** * @var Parameter */ - protected \Rubix\ML\NeuralNet\Parameter $param; + protected Parameter $param; /** * @var \Rubix\ML\NeuralNet\Optimizers\Optimizer diff --git a/tests/Transformers/ImageRotatorTest.php b/tests/Transformers/ImageRotatorTest.php index 5a88c0082..31297da76 100644 --- a/tests/Transformers/ImageRotatorTest.php +++ b/tests/Transformers/ImageRotatorTest.php @@ -17,7 +17,7 @@ class RandomizedImageRotatorTest extends TestCase /** * @var ImageRotator */ - protected \Rubix\ML\Transformers\ImageRotator $transformer; + protected ImageRotator $transformer; /** * @before