Skip to content

Commit

Permalink
math: add safe casts to log function
Browse files Browse the repository at this point in the history
  • Loading branch information
Viviane Potocnik committed Oct 20, 2023
1 parent b9fbc93 commit c076862
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
From 246137e8768c1a7c9c74e76d494ae764c40fc85e Mon Sep 17 00:00:00 2001
From: Viviane Potocnik <[email protected]>
Date: Fri, 20 Oct 2023 21:38:16 +0200
Subject: [PATCH] math: add safe casts to log function

---
src/math/log2.c | 20 ++++++++++----------
1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/math/log2.c b/src/math/log2.c
index 1276ed4e..a0c669cf 100644
--- a/src/math/log2.c
+++ b/src/math/log2.c
@@ -22,7 +22,7 @@
/* Top 16 bits of a double. */
static inline uint32_t top16(double x)
{
- return asuint64(x) >> 48;
+ return asuint64_safe(x) >> 48;
}

double log2(double x)
@@ -32,14 +32,14 @@ double log2(double x)
uint32_t top;
int k, i;

- ix = asuint64(x);
+ ix = asuint64_safe(x);
top = top16(x);
-#define LO asuint64(1.0 - 0x1.5b51p-5)
-#define HI asuint64(1.0 + 0x1.6ab2p-5)
+#define LO asuint64_safe(1.0 - 0x1.5b51p-5)
+#define HI asuint64_safe(1.0 + 0x1.6ab2p-5)
if (predict_false(ix - LO < HI - LO)) {
/* Handle close to 1.0 inputs separately. */
/* Fix sign of zero with downward rounding when x==1. */
- if (WANT_ROUNDING && predict_false(ix == asuint64(1.0)))
+ if (WANT_ROUNDING && predict_false(ix == asuint64_safe(1.0)))
return 0;
r = x - 1.0;
#if __FP_FAST_FMA
@@ -47,7 +47,7 @@ double log2(double x)
lo = r * InvLn2lo + __builtin_fma(r, InvLn2hi, -hi);
#else
double_t rhi, rlo;
- rhi = asdouble(asuint64(r) & -1ULL << 32);
+ rhi = asdouble_safe(asuint64_safe(r) & -1ULL << 32);
rlo = r - rhi;
hi = rhi * InvLn2hi;
lo = rlo * InvLn2hi + r * InvLn2lo;
@@ -67,12 +67,12 @@ double log2(double x)
/* x < 0x1p-1022 or inf or nan. */
if (ix * 2 == 0)
return __math_divzero(1);
- if (ix == asuint64(INFINITY)) /* log(inf) == inf. */
+ if (ix == asuint64_safe(INFINITY)) /* log(inf) == inf. */
return x;
if ((top & 0x8000) || (top & 0x7ff0) == 0x7ff0)
return __math_invalid(x);
/* x is subnormal, normalize it. */
- ix = asuint64(x * 0x1p52);
+ ix = asuint64_safe(x * 0x1p52);
ix -= 52ULL << 52;
}

@@ -85,7 +85,7 @@ double log2(double x)
iz = ix - (tmp & 0xfffULL << 52);
invc = T[i].invc;
logc = T[i].logc;
- z = asdouble(iz);
+ z = asdouble_safe(iz);
kd = (double_t)k;

/* log2(x) = log2(z/c) + log2(c) + k. */
@@ -99,7 +99,7 @@ double log2(double x)
double_t rhi, rlo;
/* rounding error: 0x1p-55/N + 0x1p-65. */
r = (z - T2[i].chi - T2[i].clo) * invc;
- rhi = asdouble(asuint64(r) & -1ULL << 32);
+ rhi = asdouble_safe(asuint64_safe(r) & -1ULL << 32);
rlo = r - rhi;
t1 = rhi * InvLn2hi;
t2 = rlo * InvLn2hi + r * InvLn2lo;
--
2.31.1

20 changes: 10 additions & 10 deletions sw/math/src/math/log2.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
/* Top 16 bits of a double. */
static inline uint32_t top16(double x)
{
return asuint64(x) >> 48;
return asuint64_safe(x) >> 48;
}

double log2(double x)
Expand All @@ -32,22 +32,22 @@ double log2(double x)
uint32_t top;
int k, i;

ix = asuint64(x);
ix = asuint64_safe(x);
top = top16(x);
#define LO asuint64(1.0 - 0x1.5b51p-5)
#define HI asuint64(1.0 + 0x1.6ab2p-5)
#define LO asuint64_safe(1.0 - 0x1.5b51p-5)
#define HI asuint64_safe(1.0 + 0x1.6ab2p-5)
if (predict_false(ix - LO < HI - LO)) {
/* Handle close to 1.0 inputs separately. */
/* Fix sign of zero with downward rounding when x==1. */
if (WANT_ROUNDING && predict_false(ix == asuint64(1.0)))
if (WANT_ROUNDING && predict_false(ix == asuint64_safe(1.0)))
return 0;
r = x - 1.0;
#if __FP_FAST_FMA
hi = r * InvLn2hi;
lo = r * InvLn2lo + __builtin_fma(r, InvLn2hi, -hi);
#else
double_t rhi, rlo;
rhi = asdouble(asuint64(r) & -1ULL << 32);
rhi = asdouble_safe(asuint64_safe(r) & -1ULL << 32);
rlo = r - rhi;
hi = rhi * InvLn2hi;
lo = rlo * InvLn2hi + r * InvLn2lo;
Expand All @@ -67,12 +67,12 @@ double log2(double x)
/* x < 0x1p-1022 or inf or nan. */
if (ix * 2 == 0)
return __math_divzero(1);
if (ix == asuint64(INFINITY)) /* log(inf) == inf. */
if (ix == asuint64_safe(INFINITY)) /* log(inf) == inf. */
return x;
if ((top & 0x8000) || (top & 0x7ff0) == 0x7ff0)
return __math_invalid(x);
/* x is subnormal, normalize it. */
ix = asuint64(x * 0x1p52);
ix = asuint64_safe(x * 0x1p52);
ix -= 52ULL << 52;
}

Expand All @@ -85,7 +85,7 @@ double log2(double x)
iz = ix - (tmp & 0xfffULL << 52);
invc = T[i].invc;
logc = T[i].logc;
z = asdouble(iz);
z = asdouble_safe(iz);
kd = (double_t)k;

/* log2(x) = log2(z/c) + log2(c) + k. */
Expand All @@ -99,7 +99,7 @@ double log2(double x)
double_t rhi, rlo;
/* rounding error: 0x1p-55/N + 0x1p-65. */
r = (z - T2[i].chi - T2[i].clo) * invc;
rhi = asdouble(asuint64(r) & -1ULL << 32);
rhi = asdouble_safe(asuint64_safe(r) & -1ULL << 32);
rlo = r - rhi;
t1 = rhi * InvLn2hi;
t2 = rlo * InvLn2hi + r * InvLn2lo;
Expand Down

0 comments on commit c076862

Please sign in to comment.