Skip to content

Commit

Permalink
Merge branch '2.5' into plus-plus-check
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdalpino committed Feb 8, 2024
2 parents 631981a + a354df5 commit 77a3504
Show file tree
Hide file tree
Showing 25 changed files with 486 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
with:
php-version: ${{ matrix.php-versions }}
tools: composer, pecl
extensions: svm, mbstring, gd, fileinfo
extensions: svm, mbstring, gd, fileinfo, swoole
ini-values: memory_limit=-1

- name: Validate composer.json
Expand Down
10 changes: 9 additions & 1 deletion benchmarks/Classifiers/OneVsRestBench.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@

namespace Rubix\ML\Benchmarks\Classifiers;

use Rubix\ML\Backends\Backend;
use Rubix\ML\Classifiers\OneVsRest;
use Rubix\ML\Datasets\Generators\Blob;
use Rubix\ML\Classifiers\LogisticRegression;
use Rubix\ML\NeuralNet\Optimizers\Stochastic;
use Rubix\ML\Datasets\Generators\Agglomerate;
use Rubix\ML\Tests\DataProvider\BackendProviderTrait;

/**
* @Groups({"Classifiers"})
* @BeforeMethods({"setUp"})
*/
class OneVsRestBench
{
use BackendProviderTrait;

protected const TRAINING_SIZE = 10000;

protected const TESTING_SIZE = 10000;
Expand Down Expand Up @@ -52,9 +56,13 @@ public function setUp() : void
* @Subject
* @Iterations(5)
* @OutputTimeUnit("seconds", precision=3)
* @ParamProviders("provideBackends")
* @param array{ backend: Backend } $params
*/
public function trainPredict() : void
public function trainPredict(array $params) : void
{
$this->estimator->setBackend($params['backend']);

$this->estimator->train($this->training);

$this->estimator->predict($this->testing);
Expand Down
16 changes: 14 additions & 2 deletions benchmarks/Classifiers/RandomForestBench.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,21 @@

namespace Rubix\ML\Benchmarks\Classifiers;

use Rubix\ML\Backends\Backend;
use Rubix\ML\Classifiers\RandomForest;
use Rubix\ML\Datasets\Generators\Blob;
use Rubix\ML\Classifiers\ClassificationTree;
use Rubix\ML\Datasets\Generators\Agglomerate;
use Rubix\ML\Tests\DataProvider\BackendProviderTrait;
use Rubix\ML\Transformers\IntervalDiscretizer;

/**
* @Groups({"Classifiers"})
*/
class RandomForestBench
{
use BackendProviderTrait;

protected const TRAINING_SIZE = 10000;

protected const TESTING_SIZE = 10000;
Expand Down Expand Up @@ -70,9 +74,13 @@ public function setUpCategorical() : void
* @Iterations(5)
* @BeforeMethods({"setUpContinuous"})
* @OutputTimeUnit("seconds", precision=3)
* @ParamProviders("provideBackends")
* @param array{ backend: Backend } $params
*/
public function continuous() : void
public function continuous(array $params) : void
{
$this->estimator->setBackend($params['backend']);

$this->estimator->train($this->training);

$this->estimator->predict($this->testing);
Expand All @@ -83,9 +91,13 @@ public function continuous() : void
* @Iterations(5)
* @BeforeMethods({"setUpCategorical"})
* @OutputTimeUnit("seconds", precision=3)
* @ParamProviders("provideBackends")
* @param array{ backend: Backend } $params
*/
public function categorical() : void
public function categorical(array $params) : void
{
$this->estimator->setBackend($params['backend']);

$this->estimator->train($this->training);

$this->estimator->predict($this->testing);
Expand Down
3 changes: 2 additions & 1 deletion composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@
"phpstan/extension-installer": "^1.0",
"phpstan/phpstan": "^1.0",
"phpstan/phpstan-phpunit": "^1.0",
"phpunit/phpunit": "^9.0"
"phpunit/phpunit": "^9.0",
"swoole/ide-helper": "^5.1"
},
"suggest": {
"ext-tensor": "For fast Matrix/Vector computing",
Expand Down
1 change: 1 addition & 0 deletions phpstan.neon
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ parameters:
- 'benchmarks'
excludePaths:
- src/Backends/Amp.php
- src/Backends/Swoole.php
15 changes: 14 additions & 1 deletion phpunit.xml
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
<?xml version="1.0" encoding="UTF-8"?>
<phpunit xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" backupGlobals="false" backupStaticAttributes="false" bootstrap="vendor/autoload.php" colors="true" convertErrorsToExceptions="true" convertNoticesToExceptions="true" convertWarningsToExceptions="true" forceCoversAnnotation="true" processIsolation="false" stopOnFailure="false" xsi:noNamespaceSchemaLocation="https://schema.phpunit.de/9.3/phpunit.xsd">
<phpunit
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
backupGlobals="false"
backupStaticAttributes="false"
bootstrap="vendor/autoload.php"
colors="true"
convertErrorsToExceptions="true"
convertNoticesToExceptions="true"
convertWarningsToExceptions="true"
forceCoversAnnotation="true"
processIsolation="true"
stopOnFailure="false"
xsi:noNamespaceSchemaLocation="https://schema.phpunit.de/9.3/phpunit.xsd"
>
<coverage processUncoveredFiles="true">
<include>
<directory suffix=".php">src</directory>
Expand Down
173 changes: 173 additions & 0 deletions src/Backends/Swoole.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
<?php

namespace Rubix\ML\Backends;

use Rubix\ML\Backends\Tasks\Task;
use Rubix\ML\Specifications\ExtensionIsLoaded;
use Rubix\ML\Specifications\SwooleExtensionIsLoaded;
use RuntimeException;
use Swoole\Atomic;
use Swoole\Process;

use function Swoole\Coroutine\run;

/**
 * Swoole
 *
 * Parallel backend that runs every queued task in its own Swoole worker
 * process and collects the serialized results over the worker sockets.
 *
 * Works both with Swoole and OpenSwoole.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 */
class Swoole implements Backend
{
    /**
     * The queue of tasks to be processed in parallel.
     *
     * @var list<callable():mixed>
     */
    protected array $queue = [];

    /**
     * The number of CPUs reported by Swoole, used to spread workers over cores.
     */
    private int $cpus;

    /**
     * Whether the igbinary extension is available for (un)serializing results.
     */
    private bool $hasIgbinary;

    /**
     * Verify that the Swoole extension is loaded and detect the optional features.
     */
    public function __construct()
    {
        SwooleExtensionIsLoaded::create()->check();

        $this->cpus = swoole_cpu_num();
        $this->hasIgbinary = (bool) ExtensionIsLoaded::with('igbinary')->passes();
    }

    /**
     * Queue up a deferred task for backend processing.
     *
     * NOTE(review): the $after callback is invoked from inside the queued
     * closure, which process() executes within the worker callback - so its
     * side effects presumably do not reach the parent process. Confirm.
     *
     * @internal
     *
     * @param Task $task
     * @param (callable(mixed,mixed):void)|null $after
     * @param mixed $context
     */
    public function enqueue(Task $task, ?callable $after = null, $context = null) : void
    {
        $this->queue[] = function () use ($task, $after, $context) {
            $result = $task();

            if ($after) {
                $after($result, $context);
            }

            return $result;
        };
    }

    /**
     * Process the queue and return the results.
     *
     * One worker process is started per queued item. Each worker serializes
     * its result and sends it as a single datagram over its socket; the parent
     * then reads one message per worker, in queue order.
     *
     * @internal
     *
     * @throws RuntimeException if a worker fails or its result cannot be read
     * @return mixed[]
     */
    public function process() : array
    {
        $results = [];

        // Shared counter holding the largest serialized payload of any worker.
        // The parent uses it as the receive buffer size.
        $maxMessageLength = new Atomic(0);
        $workerProcesses = [];

        $currentCpu = 0;

        foreach ($this->queue as $index => $queueItem) {
            $workerProcess = new Process(
                function (Process $worker) use ($maxMessageLength, $queueItem) {
                    $serialized = $this->serialize($queueItem());

                    $serializedLength = strlen($serialized);

                    // Raise the shared maximum with a compare-and-set loop. A
                    // separate get() followed by set() lets two workers
                    // interleave so that a smaller length overwrites a larger
                    // one, which would leave the parent with a receive buffer
                    // too small for the larger datagram.
                    do {
                        $currentMaxLength = $maxMessageLength->get();
                    } while (
                        $serializedLength > $currentMaxLength
                        && !$maxMessageLength->cmpset($currentMaxLength, $serializedLength)
                    );

                    $worker->exportSocket()->send($serialized);
                },
                // redirect_stdin_and_stdout
                false,
                // pipe_type
                SOCK_DGRAM,
                // enable_coroutine
                true,
            );

            // Assign the workers to the available CPUs in round-robin order.
            $workerProcess->setAffinity([$currentCpu]);
            $workerProcess->setBlocking(false);
            $workerProcess->start();

            $workerProcesses[$index] = $workerProcess;

            $currentCpu = ($currentCpu + 1) % $this->cpus;
        }

        run(function () use ($maxMessageLength, &$results, $workerProcesses) {
            foreach ($workerProcesses as $index => $workerProcess) {
                $status = $workerProcess->wait();

                // wait() yields no status array when there is nothing to reap.
                if (!is_array($status) || 0 !== $status['code']) {
                    throw new RuntimeException('Worker process exited with an error');
                }

                $socket = $workerProcess->exportSocket();

                if ($socket->isClosed()) {
                    throw new RuntimeException('Coroutine socket is closed');
                }

                $maxMessageLengthValue = $maxMessageLength->get();

                $receivedData = $socket->recv($maxMessageLengthValue);

                // A failed or empty read must not be unserialized into a bogus result.
                if (!is_string($receivedData) || '' === $receivedData) {
                    throw new RuntimeException('Failed to receive the result from the worker process');
                }

                $results[] = $this->unserialize($receivedData);
            }
        });

        return $results;
    }

    /**
     * Flush the queue
     */
    public function flush() : void
    {
        $this->queue = [];
    }

    /**
     * Serialize a task result, using igbinary when it is available.
     */
    private function serialize(mixed $data) : string
    {
        if ($this->hasIgbinary) {
            return igbinary_serialize($data);
        }

        return serialize($data);
    }

    /**
     * Restore a task result using the same format that serialize() produced.
     */
    private function unserialize(string $serialized) : mixed
    {
        if ($this->hasIgbinary) {
            return igbinary_unserialize($serialized);
        }

        return unserialize($serialized);
    }

    /**
     * Return the string representation of the object.
     *
     * @internal
     *
     * @return string
     */
    public function __toString() : string
    {
        return 'Swoole';
    }
}
16 changes: 16 additions & 0 deletions src/Classifiers/LogisticRegression.php
Original file line number Diff line number Diff line change
Expand Up @@ -491,4 +491,20 @@ public function __toString() : string
{
return 'Logistic Regression (' . Params::stringify($this->params()) . ')';
}

/**
 * Restore the object state by copying every entry of the serialized data
 * onto the property of the same name.
 *
 * Without this method, the Swoole backend combined with Igbinary
 * serialization causes errors. It can be removed once that is no longer
 * the case.
 *
 * @internal
 *
 * @param array<string,mixed> $data
 */
public function __unserialize(array $data) : void
{
    foreach (array_keys($data) as $property) {
        $this->{$property} = $data[$property];
    }
}
}
20 changes: 20 additions & 0 deletions src/EstimatorWrapper.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
<?php

namespace Rubix\ML;

/**
 * Estimator Wrapper
 *
 * Contract for estimators that wrap another estimator and expose the wrapped
 * instance through base().
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Ronan Giron
 */
interface EstimatorWrapper extends Estimator
{
/**
 * Return the base estimator instance, i.e. the estimator being wrapped.
 *
 * @return Estimator
 */
public function base() : Estimator;
}
2 changes: 1 addition & 1 deletion src/GridSearch.php
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
* @package Rubix/ML
* @author Andrew DalPino
*/
class GridSearch implements Estimator, Learner, Parallel, Verbose, Persistable
class GridSearch implements EstimatorWrapper, Learner, Parallel, Verbose, Persistable
{
use AutotrackRevisions, Multiprocessing, LoggerAware;

Expand Down
2 changes: 1 addition & 1 deletion src/PersistentModel.php
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
* @package Rubix/ML
* @author Andrew DalPino
*/
class PersistentModel implements Estimator, Learner, Probabilistic, Scoring
class PersistentModel implements EstimatorWrapper, Learner, Probabilistic, Scoring
{
/**
* The persistable base learner.
Expand Down
2 changes: 1 addition & 1 deletion src/Pipeline.php
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
* @package Rubix/ML
* @author Andrew DalPino
*/
class Pipeline implements Online, Probabilistic, Scoring, Persistable
class Pipeline implements Online, Probabilistic, Scoring, Persistable, EstimatorWrapper
{
use AutotrackRevisions;

Expand Down
Loading

0 comments on commit 77a3504

Please sign in to comment.