Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle IEEE Floating Point Special Values #105

Merged
merged 6 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions cbits/tape.c
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,23 @@ void tape_backPropagate(void* p, int start, double* out)
while (--idx >= 0)
{
double v = buffer[idx + pTape->offset];
if (v == 0.0) continue;

// TODO: if we do not care about handling IEEE floating point special values (NaN, Inf) correctly
// then we can skip the rest of the loop body in case v == 0
// see also https://github.com/ekmett/ad/issues/106

int i = pTape->lnk[idx*2];
double x = pTape->val[idx*2];
if (x != 0.0)
if (i >= 0)
{
buffer[i] += v*x;
double x = v * pTape->val[idx*2];
if (x != 0) buffer[i] += x;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh i see what you mean, i didn't realize this new x was the old x * v stuff

}

int j = pTape->lnk[idx*2 + 1];
double y = pTape->val[idx*2 + 1];
if (y != 0.0)
if (j >= 0)
{
buffer[j] += v*y;
double y = v * pTape->val[idx*2 + 1];
if (y != 0) buffer[j] += y;
}
}
idx += 1 + pTape->offset;
Expand All @@ -122,4 +125,4 @@ void tape_free(void* p)
pTape = pTape->prev;
free(p);
}
}
}
2 changes: 1 addition & 1 deletion src/Numeric/AD/Internal/Reverse/Double.hs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ reifyTypeableTape vs k = unsafePerformIO $ fmap (\t -> reifyTypeable t k) (newTa
-- | This is used to create a new entry on the chain given a unary function, its derivative with respect to its input,
-- the variable ID of its input, and the value of its input. Used by 'unary' and 'binary' internally.
unarily :: forall s. Reifies s Tape => (Double -> Double) -> Double -> Int -> Double -> ReverseDouble s
unarily f di i b = ReverseDouble (unsafePerformIO (pushTape (Proxy :: Proxy s) i 0 di 0.0)) $! f b
unarily f di i b = ReverseDouble (unsafePerformIO (pushTape (Proxy :: Proxy s) i (-1) di 0.0)) $! f b
{-# INLINE unarily #-}

-- | This is used to create a new entry on the chain given a binary function, its derivatives with respect to its inputs,
Expand Down
109 changes: 100 additions & 9 deletions tests/Regression.hs
Original file line number Diff line number Diff line change
@@ -1,25 +1,116 @@
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE RankNTypes #-}

module Main (main) where

import qualified Numeric.AD.Mode.Reverse as R
import qualified Numeric.AD.Mode.Reverse.Double as RD

import Text.Printf
import Test.Tasty
import Test.Tasty.HUnit

-- | Shape of a derivative operator under test: given a function that is
-- polymorphic over every 'Floating' type and a point, it yields a 'Double'.
type Diff = (forall a. Floating a => a -> a) -> Double -> Double
-- | Shape of a gradient operator under test: a polymorphic scalar-valued
-- function of a list, plus a list of inputs, gives a list of 'Double's.
type Grad = (forall a. Floating a => [a] -> a) -> [Double] -> [Double]
-- | Shape of a Jacobian operator under test: a polymorphic list-valued
-- function of a list, plus a list of inputs, gives a list of rows.
type Jacobian = (forall a. Floating a => [a] -> [a]) -> [Double] -> [[Double]]
-- | Shape of a Hessian operator under test: a polymorphic scalar-valued
-- function of a list, plus a list of inputs, gives a list of rows.
type Hessian = (forall a. Floating a => [a] -> a) -> [Double] -> [[Double]]

-- | Entry point of the test executable: runs the 'tests' tree with 'defaultMain'.
main :: IO ()
main = defaultMain tests

-- | Regression suite with a single case for issue #97: the generic reverse
-- mode and the Double-specialised reverse mode must produce the same
-- derivative of 'f' at 0 (up to 'nearZero').
tests :: TestTree
tests =
  testGroup
    "Regression tests"
    [ testCase "#97" $
        assertBool
          "Reverse.diff and Reverse.Double.diff should behave identically"
          (nearZero (viaReverse - viaReverseDouble))
    ]
  where
    viaReverse = R.diff f (0 :: Double)
    viaReverseDouble = RD.diff f (0 :: Double)
-- Runs the same checks once per reverse-mode implementation: the generic
-- one (qualified as R) and the Double-specialised one (qualified as RD).
-- NOTE(review): the eta-expanded lambdas look deliberate, since every
-- operator is passed at a rank-2 alias type; confirm it still type-checks
-- before eta-reducing them.
tests =
  testGroup
    "tests"
    [ mode
        "reverse"
        (\fn -> R.diff fn)
        (\fn -> R.grad fn)
        (\fn -> R.jacobian fn)
        (\fn -> R.hessian fn)
    , mode
        "reverse-double"
        (\fn -> RD.diff fn)
        (\fn -> RD.grad fn)
        (\fn -> RD.jacobian fn)
        (\fn -> RD.hessian fn)
    ]

-- | Builds the test group for one AD mode.
--
-- The group is labelled @name@ and contains the basic checks plus the
-- regression cases for issues 97 and 104, all run against the supplied
-- operators.
mode :: String -> Diff -> Grad -> Jacobian -> Hessian -> TestTree
mode name diff grad jacobian hessian = testGroup name suites
  where
    suites =
      [ basic diff grad jacobian hessian
      , issue97 diff
      , issue104 diff grad
      ]

-- | Basic checks of the four operators against hand-computed derivatives,
-- including points where the expected result is NaN or infinite.
basic :: Diff -> Grad -> Jacobian -> Hessian -> TestTree
basic diff grad jacobian hessian =
  testGroup "basic" [diffCase, gradCase, jacobianCase, hessianCase]
  where
    diffCase = testCase "diff" $ do
      assertNearList [11, 5.5, 3, 3.5, 7, 13.5, 23, 35.5, 51] $
        map (diff cubic) [-2, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, 2]
      assertNearList [nan, inf, 1, 0.5, 0.25] $
        map (diff sqrt) [-1, 0, 0.25, 1, 4]
      assertNearList [1, 0, 1] $
        [diff sin, diff cos, diff tan] <*> [0]
      assertNearList [-1, 0, 1] $
        map (diff abs) [-1, 0, 1]
      assertNearList [1, exp 1, inf, 1] $
        [diff exp, diff log] <*> [0, 1]

    gradCase = testCase "grad" $ do
      assertNearList [2, 1, 1] $ grad mulAdd [1, 2, 3]
      assertNearList [1, 0.25] $ grad rootOfProduct [2, 8]
      assertNearList [0, nan] $ grad powPair [0, 2]

    jacobianCase = testCase "jacobian" $ do
      assertNearMatrix [[0, 1], [1, 0], [1, 2]] $ jacobian swapAndMul [2, 1]

    hessianCase = testCase "hessian" $ do
      assertNearMatrix [[0, 1, 0], [1, 0, 0], [0, 0, 0]] $ hessian mulAdd [1, 2, 3]
      assertNearMatrix [[0, 0], [0, 0]] $ hessian addPair [1, 2]
      assertNearMatrix [[0, 1], [1, 0]] $ hessian mulPair [1, 2]
      assertNearMatrix [[2, 1], [1, 0]] $ hessian powPair [1, 2]

    -- Functions under differentiation; each expects a list of exactly the
    -- length it pattern-matches on.
    addPair [x, y] = x + y
    mulPair [x, y] = x * y
    powPair [x, y] = x ** y
    mulAdd [x, y, z] = x * y + z
    swapAndMul [x, y] = [y, x, x * y]
    rootOfProduct [x, y] = sqrt $ x * y
    cubic x = 12 + 7 * x + 5 * x ^ 2 + 2 * x ^ 3

-- With +ffi, Reverse.Double starts its tape with one block of size 4096.
-- Summing 5000 copies of the argument makes the term large enough that an
-- additional block has to be allocated.
f :: Num a => a -> a
f x = sum (replicate 5000 x)
-- | Regression case for issue 97: differentiates a sum of 5000 copies of the
-- argument at 0 and expects a derivative of 5000.
issue97 :: Diff -> TestTree
issue97 diff = testCase "issue-97" (assertNear 5000 (diff summed 0))
  where
    summed x = sum (replicate 5000 x)

-- | Regression cases for issue 104: derivatives that involve NaN and
-- infinity, for a square root of a product ("inside") and a product of
-- square roots ("outside").
issue104 :: Diff -> Grad -> TestTree
issue104 diff grad = testGroup "issue-104" [inside, outside]
  where
    -- The product sits inside a single square root.
    inside = testGroup "inside" [diffCase, gradCase]
      where
        diffCase = testCase "diff" $ do
          assertNearList [nan, nan] $ map (diff (\y -> rootOfProduct 0 y)) [0, 1]
          assertNearList [inf, 0.5] $ map (diff (\y -> rootOfProduct 1 y)) [0, 1]
          assertNearList [nan, nan] $ map (diff (\x -> rootOfProduct x 0)) [0, 1]
          assertNearList [inf, 0.5] $ map (diff (\x -> rootOfProduct x 1)) [0, 1]
        gradCase = testCase "grad" $ do
          assertNearList [nan, nan] $ grad (onPair rootOfProduct) [0, 0]
          assertNearList [nan, inf] $ grad (onPair rootOfProduct) [1, 0]
          assertNearList [inf, nan] $ grad (onPair rootOfProduct) [0, 1]
          assertNearList [0.5, 0.5] $ grad (onPair rootOfProduct) [1, 1]
        -- With r = rootOfProduct x y, the gradient at [x, y] is
        -- [y / (2 * r), x / (2 * r)].
        rootOfProduct x y = sqrt $ x * y

    -- Each factor gets its own square root, multiplied outside.
    outside = testGroup "outside" [diffCase, gradCase]
      where
        diffCase = testCase "diff" $ do
          assertNearList [nan, 0.0] $ map (diff (\y -> productOfRoots 0 y)) [0, 1]
          assertNearList [inf, 0.5] $ map (diff (\y -> productOfRoots 1 y)) [0, 1]
          assertNearList [nan, 0.0] $ map (diff (\x -> productOfRoots x 0)) [0, 1]
          assertNearList [inf, 0.5] $ map (diff (\x -> productOfRoots x 1)) [0, 1]
        gradCase = testCase "grad" $ do
          assertNearList [nan, nan] $ grad (onPair productOfRoots) [0, 0]
          assertNearList [0.0, inf] $ grad (onPair productOfRoots) [1, 0]
          assertNearList [inf, 0.0] $ grad (onPair productOfRoots) [0, 1]
          assertNearList [0.5, 0.5] $ grad (onPair productOfRoots) [1, 1]
        -- The gradient at [x, y] is
        -- [sqrt y / (2 * sqrt x), sqrt x / (2 * sqrt y)].
        productOfRoots x y = sqrt x * sqrt y

    -- Adapts a two-argument function to one taking a two-element list.
    onPair fn = \[x, y] -> fn x y

-- | Approximate equality that also understands IEEE special values.
--
-- Two values are near each other when both are NaN, when both are infinite
-- with the same sign, or when they differ by at most @1e-12@.
near :: Double -> Double -> Bool
near a b
  | isNaN a && isNaN b = True
  | isInfinite a && isInfinite b && signum a == signum b = True
  | otherwise = abs (a - b) <= 1e-12

-- | Element-wise 'near' on two lists; lists of different length never match.
nearList :: [Double] -> [Double] -> Bool
nearList xs ys
  | length xs /= length ys = False
  | otherwise = all (uncurry near) (zip xs ys)

-- | Row-wise 'nearList' on two lists of rows; a different number of rows
-- never matches.
nearMatrix :: [[Double]] -> [[Double]] -> Bool
nearMatrix rowsA rowsB
  | length rowsA /= length rowsB = False
  | otherwise = all (uncurry nearList) (zip rowsA rowsB)

-- | Asserts that the second value is 'near' the first (the expected one),
-- failing with the message built by 'expect' otherwise.
assertNear :: Double -> Double -> Assertion
assertNear a b = assertBool (expect a b) (near a b)

-- | List version of the nearness assertion: passes when 'nearList' holds for
-- the expected list (first) and the actual list (second), and otherwise
-- fails with the message built by 'expect'.
assertNearList :: [Double] -> [Double] -> Assertion
assertNearList a b = assertBool (expect a b) (nearList a b)

-- | Matrix version of the nearness assertion: passes when 'nearMatrix' holds
-- for the expected rows (first) and the actual rows (second), and otherwise
-- fails with the message built by 'expect'.
assertNearMatrix :: [[Double]] -> [[Double]] -> Assertion
assertNearMatrix a b = assertBool (expect a b) (nearMatrix a b)

-- | Renders the failure message for a mismatch, showing the expected value
-- first and the value actually obtained second.
expect :: Show a => a -> a -> String
expect a b = printf "expected %s but got %s" wanted got
  where
    wanted = show a
    got = show b

-- | A NaN value, produced by dividing zero by zero.
nan :: Double
nan = 0 / 0

-- | True when the magnitude of the argument is at most @1e-12@.
nearZero :: (Fractional a, Ord a) => a -> Bool
nearZero = (<= 1e-12) . abs
-- | Positive infinity, produced by dividing one by zero.
inf :: Double
inf = 1 / 0