Skip to content

Commit

Permalink
Allow normalizers to skip NaN values (#333)
Browse files Browse the repository at this point in the history
* Add is_finite checks to skip NAN values

* Tests and fix issue with samples that are all non-finite
  • Loading branch information
27pchrisl authored Dec 26, 2024
1 parent 37730cb commit 646b1a2
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/Transformers/MaxAbsoluteScaler.php
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public function update(Dataset $dataset) : void
foreach ($this->maxabs as $column => $oldMax) {
$values = $dataset->feature($column);

$max = max(array_map('abs', $values));
$max = max(array_map('abs', array_filter($values, 'is_finite') ?: [0]));

$max = max($oldMax, $max);

Expand Down
12 changes: 10 additions & 2 deletions src/Transformers/MinMaxNormalizer.php
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ public function fit(Dataset $dataset) : void
$values = $dataset->feature($column);

/** @var int|float $min */
$min = min($values);
$min = min(array_filter($values, 'is_finite') ?: [0]);

/** @var int|float $max */
$max = max($values);
$max = max(array_filter($values, 'is_finite') ?: [0]);

$scale = ($this->max - $this->min) / (($max - $min) ?: EPSILON);

Expand Down Expand Up @@ -199,6 +199,10 @@ public function transform(array &$samples) : void
foreach ($this->scales as $column => $scale) {
$value = &$sample[$column];

if (!is_finite($value)) {
continue;
}

$min = $this->minimums[$column];

$value *= $scale;
Expand All @@ -224,6 +228,10 @@ public function reverseTransform(array &$samples) : void
foreach ($this->scales as $column => $scale) {
$value = &$sample[$column];

if (!is_finite($value)) {
continue;
}

$min = $this->minimums[$column];

$value -= $this->min - $min * $scale;
Expand Down
11 changes: 11 additions & 0 deletions tests/Transformers/MaxAbsoluteScalerTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,15 @@ public function reverseTransformUnfitted() : void

$this->transformer->reverseTransform($samples);
}

/**
* @test
*/
public function skipsNonFinite(): void
{
$samples = Unlabeled::build([[0.0, 3000.0, NAN, -6.0], [1.0, 30.0, NAN, 0.001]]);
$this->transformer->fit($samples);
$this->assertNan($samples[0][2]);
$this->assertNan($samples[1][2]);
}
}
11 changes: 11 additions & 0 deletions tests/Transformers/MinMaxNormalizerTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,15 @@ public function transformUnfitted() : void

$this->transformer->transform($samples);
}

/**
* @test
*/
public function skipsNonFinite(): void
{
$samples = Unlabeled::build([[0.0, 3000.0, NAN, -6.0], [1.0, 30.0, NAN, 0.001]]);
$this->transformer->fit($samples);
$this->assertNan($samples[0][2]);
$this->assertNan($samples[1][2]);
}
}

0 comments on commit 646b1a2

Please sign in to comment.