Skip to content

Commit

Permalink
sw/math: Implement safe tanh function
Browse files Browse the repository at this point in the history
  • Loading branch information
colluca committed Nov 8, 2023
1 parent 4a110f0 commit 9fe8d86
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 13 deletions.
149 changes: 149 additions & 0 deletions sw/deps/patches/musl/0004-sw-math-Implement-safe-tanh-function.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
From b419b07facc9591ba0d8683f53c9adefb8a9b0c6 Mon Sep 17 00:00:00 2001
From: Luca Colagrande <[email protected]>
Date: Wed, 8 Nov 2023 09:35:17 +0100
Subject: [PATCH] `sw/math`: Implement safe `tanh` function

---
src/internal/libm.h | 31 +++++++++++++++++++++++++++++++
src/math/expm1.c | 34 ++++++++++++++++++++++++++--------
src/math/tanh.c | 17 ++++++++++++-----
3 files changed, 69 insertions(+), 13 deletions(-)

diff --git a/src/internal/libm.h b/src/internal/libm.h
index 60b9866..c96c0ec 100644
--- a/src/internal/libm.h
+++ b/src/internal/libm.h
@@ -96,6 +96,37 @@ static int32_t converttoint(double_t);
#define predict_false(x) (x)
#endif

+/* Memory-consistent functions to manipulate the upper word of a
+ double-precision floating-point number in the integer core.
+ Since there is no dedicated instruction to move the upper 32-bits
+ of a double-precision floating point register to an integer register
+ the compiler resorts to moving the value through the memory. However in
+ Snitch neither the program ordering between floating-point and integer
+ instructions is guaranteed, nor is memory consistency between the integer
+ and floating-point threads. */
+
+static inline uint32_t safe_extract_upper_32b_from_double(double x) {
+ double f;
+ uint32_t result;
+ asm volatile("fsd %[x], 0(%[ptr]) \n"
+ "fld ft3, 0(%[ptr]) \n"
+ "fmv.x.w t0, ft3 \n"
+ "mv t0, t0 \n"
+ "lw %[result], 4(%[ptr]) \n"
+ : [result]"=r"(result) : [x]"f"(x), [ptr]"r"(&f): "ft3", "t0", "memory");
+ return result;
+}
+
+static inline void safe_inject_into_upper_32b_double(uint32_t x, double *f) {
+ asm volatile("sw %[x], 4(%[ptr]) \n"
+ "lw %[x], 4(%[ptr]) \n"
+ "fmv.w.x ft3, %[x] \n"
+ : : [x]"r"(x), [ptr]"r"(f): "ft3", "memory");
+}
+
+/* TODO: the following functions are not really safe, compare previous two
+ functions */
+
/* FPU fence to synchronize the FPU and integer core in Snitch. */
inline void snrt_fpu_fence() {
unsigned tmp;
diff --git a/src/math/expm1.c b/src/math/expm1.c
index ac1e61e..d94f57f 100644
--- a/src/math/expm1.c
+++ b/src/math/expm1.c
@@ -121,9 +121,14 @@ Q5 = -2.01099218183624371326e-07; /* BE8AFDB7 6E09C32D */
double expm1(double x)
{
double_t y,hi,lo,c,t,e,hxs,hfx,r1,twopk;
- union {double f; uint64_t i;} u = {x};
- uint32_t hx = u.i>>32 & 0x7fffffff;
- int k, sign = u.i>>63;
+ /// Original implementation
+ // union {double f; uint64_t i;} u = {x};
+ // uint32_t hx = u.i>>32 & 0x7fffffff;
+ // int k, sign = u.i>>63;
+ /// Safe implementation in Snitch
+ uint32_t upper_32b_x = safe_extract_upper_32b_from_double(x);
+ uint32_t hx = upper_32b_x & 0x7fffffff;
+ int k, sign = upper_32b_x>>31;

/* filter out huge and non-finite argument */
if (hx >= 0x4043687A) { /* if |x|>=56*ln2 */
@@ -182,8 +187,12 @@ double expm1(double x)
return -2.0*(e-(x+0.5));
return 1.0+2.0*(x-e);
}
- u.i = (uint64_t)(0x3ff + k)<<52; /* 2^k */
- twopk = u.f;
+ /// Original implementation
+ // u.i = (uint64_t)(0x3ff + k)<<52; /* 2^k */
+ // twopk = u.f;
+ /// Safe implementation in Snitch
+ uint32_t u_i = (uint32_t)(0x3ff + k)<<20;
+ safe_inject_into_upper_32b_double(u_i, &twopk);
if (k < 0 || k > 56) { /* suffice to return exp(x)-1 */
y = x - e + 1.0;
if (k == 1024)
@@ -192,10 +201,19 @@ double expm1(double x)
y = y*twopk;
return y - 1.0;
}
- u.i = (uint64_t)(0x3ff - k)<<52; /* 2^-k */
+ /// Original implementation
+ // u.i = (uint64_t)(0x3ff - k)<<52; /* 2^-k */
+ // if (k < 20)
+ // y = (x-e+(1-u.f))*twopk;
+ // else
+ // y = (x-(e+u.f)+1)*twopk;
+ /// Safe implementation in Snitch
+ u_i = (uint32_t)(0x3ff - k)<<20;
+ double u_f = 0;
+ safe_inject_into_upper_32b_double(u_i, &u_f);
if (k < 20)
- y = (x-e+(1-u.f))*twopk;
+ y = (x-e+(1-u_f))*twopk;
else
- y = (x-(e+u.f)+1)*twopk;
+ y = (x-(e+u_f)+1)*twopk;
return y;
}
diff --git a/src/math/tanh.c b/src/math/tanh.c
index 20d6dbc..2481db1 100644
--- a/src/math/tanh.c
+++ b/src/math/tanh.c
@@ -6,16 +6,23 @@
*/
double tanh(double x)
{
- union {double f; uint64_t i;} u = {.f = x};
uint32_t w;
int sign;
double_t t;

/* x = |x| */
- sign = u.i >> 63;
- u.i &= (uint64_t)-1/2;
- x = u.f;
- w = u.i >> 32;
+ /// Original implementation
+ // union {double f; uint64_t i;} u = {.f = x};
+ // sign = u.i >> 63;
+ // u.i &= (uint64_t)-1/2;
+ // x = u.f;
+ // w = u.i >> 32;
+ /// Safe implementation in Snitch
+ uint32_t upper_32b_x = safe_extract_upper_32b_from_double(x);
+ sign = upper_32b_x >> 31;
+ uint32_t sign_mask = (~(1 << 31));
+ w = upper_32b_x & sign_mask;
+ safe_inject_into_upper_32b_double(w, &x);

if (w > 0x3fe193ea) {
/* |x| > log(3)/2 ~= 0.5493 or nan */
--
2.28.0

31 changes: 31 additions & 0 deletions sw/math/src/internal/libm.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,37 @@ static int32_t converttoint(double_t);
#define predict_false(x) (x)
#endif

/* Memory-consistent functions to manipulate the upper word of a
double-precision floating-point number in the integer core.
Since there is no dedicated instruction to move the upper 32-bits
of a double-precision floating point register to an integer register
the compiler resorts to moving the value through the memory. However in
Snitch neither the program ordering between floating-point and integer
instructions is guaranteed, nor is memory consistency between the integer
and floating-point threads. */

static inline uint32_t safe_extract_upper_32b_from_double(double x) {
double f;
uint32_t result;
asm volatile("fsd %[x], 0(%[ptr]) \n"
"fld ft3, 0(%[ptr]) \n"
"fmv.x.w t0, ft3 \n"
"mv t0, t0 \n"
"lw %[result], 4(%[ptr]) \n"
: [result]"=r"(result) : [x]"f"(x), [ptr]"r"(&f): "ft3", "t0", "memory");
return result;
}

static inline void safe_inject_into_upper_32b_double(uint32_t x, double *f) {
asm volatile("sw %[x], 4(%[ptr]) \n"
"lw %[x], 4(%[ptr]) \n"
"fmv.w.x ft3, %[x] \n"
: : [x]"r"(x), [ptr]"r"(f): "ft3", "memory");
}

/* TODO: the following functions are not really safe, compare previous two
functions */

/* FPU fence to synchronize the FPU and integer core in Snitch. */
inline void snrt_fpu_fence() {
unsigned tmp;
Expand Down
34 changes: 26 additions & 8 deletions sw/math/src/math/expm1.c
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,14 @@ Q5 = -2.01099218183624371326e-07; /* BE8AFDB7 6E09C32D */
double expm1(double x)
{
double_t y,hi,lo,c,t,e,hxs,hfx,r1,twopk;
union {double f; uint64_t i;} u = {x};
uint32_t hx = u.i>>32 & 0x7fffffff;
int k, sign = u.i>>63;
/// Original implementation
// union {double f; uint64_t i;} u = {x};
// uint32_t hx = u.i>>32 & 0x7fffffff;
// int k, sign = u.i>>63;
/// Safe implementation in Snitch
uint32_t upper_32b_x = safe_extract_upper_32b_from_double(x);
uint32_t hx = upper_32b_x & 0x7fffffff;
int k, sign = upper_32b_x>>31;

/* filter out huge and non-finite argument */
if (hx >= 0x4043687A) { /* if |x|>=56*ln2 */
Expand Down Expand Up @@ -182,8 +187,12 @@ double expm1(double x)
return -2.0*(e-(x+0.5));
return 1.0+2.0*(x-e);
}
u.i = (uint64_t)(0x3ff + k)<<52; /* 2^k */
twopk = u.f;
/// Original implementation
// u.i = (uint64_t)(0x3ff + k)<<52; /* 2^k */
// twopk = u.f;
/// Safe implementation in Snitch
uint32_t u_i = (uint32_t)(0x3ff + k)<<20;
safe_inject_into_upper_32b_double(u_i, &twopk);
if (k < 0 || k > 56) { /* suffice to return exp(x)-1 */
y = x - e + 1.0;
if (k == 1024)
Expand All @@ -192,10 +201,19 @@ double expm1(double x)
y = y*twopk;
return y - 1.0;
}
u.i = (uint64_t)(0x3ff - k)<<52; /* 2^-k */
/// Original implementation
// u.i = (uint64_t)(0x3ff - k)<<52; /* 2^-k */
// if (k < 20)
// y = (x-e+(1-u.f))*twopk;
// else
// y = (x-(e+u.f)+1)*twopk;
/// Safe implementation in Snitch
u_i = (uint32_t)(0x3ff - k)<<20;
double u_f = 0;
safe_inject_into_upper_32b_double(u_i, &u_f);
if (k < 20)
y = (x-e+(1-u.f))*twopk;
y = (x-e+(1-u_f))*twopk;
else
y = (x-(e+u.f)+1)*twopk;
y = (x-(e+u_f)+1)*twopk;
return y;
}
17 changes: 12 additions & 5 deletions sw/math/src/math/tanh.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,23 @@
*/
double tanh(double x)
{
union {double f; uint64_t i;} u = {.f = x};
uint32_t w;
int sign;
double_t t;

/* x = |x| */
sign = u.i >> 63;
u.i &= (uint64_t)-1/2;
x = u.f;
w = u.i >> 32;
/// Original implementation
// union {double f; uint64_t i;} u = {.f = x};
// sign = u.i >> 63;
// u.i &= (uint64_t)-1/2;
// x = u.f;
// w = u.i >> 32;
/// Safe implementation in Snitch
uint32_t upper_32b_x = safe_extract_upper_32b_from_double(x);
sign = upper_32b_x >> 31;
uint32_t sign_mask = (~(1 << 31));
w = upper_32b_x & sign_mask;
safe_inject_into_upper_32b_double(w, &x);

if (w > 0x3fe193ea) {
/* |x| > log(3)/2 ~= 0.5493 or nan */
Expand Down

0 comments on commit 9fe8d86

Please sign in to comment.