From 05093dd1baef3d89181938bb51228622ea6ceb3f Mon Sep 17 00:00:00 2001
From: Richard Patel
Date: Wed, 4 Dec 2024 02:32:51 +0000
Subject: [PATCH] Add AVX 16x16 bitwise ops

---
 src/ballet/reedsol/fd_reedsol_pi.c | 81 ++++++++++++++----------------
 src/util/simd/fd_avx_wh.h          | 32 ++++++++++++
 src/util/simd/fd_avx_ws.h          | 35 +++++++++++++
 src/util/simd/test_avx_16x16.c     | 68 +++++++++++++++++++++++++
 4 files changed, 174 insertions(+), 42 deletions(-)

diff --git a/src/ballet/reedsol/fd_reedsol_pi.c b/src/ballet/reedsol/fd_reedsol_pi.c
index 7b389af920..44601612b8 100644
--- a/src/ballet/reedsol/fd_reedsol_pi.c
+++ b/src/ballet/reedsol/fd_reedsol_pi.c
@@ -69,9 +69,6 @@
 #include "../../util/simd/fd_sse.h"
 
 #define ws_adjust_sign(a,b) _mm256_sign_epi16( (a), (b) ) /* scales elements in a by the sign of the corresponding element of b */
-#define ws_shl(a,imm)  _mm256_slli_epi16( (a), (imm) )
-#define ws_and(a,b)    _mm256_and_si256( (a), (b) )
-#define ws_shru(a,imm) _mm256_srli_epi16( (a), (imm) )
 
 static inline ws_t
 ws_mod255( ws_t x ) {
@@ -79,7 +76,7 @@ ws_mod255( ws_t x ) {
      (x%255) == 0xFF & ( x + (x*0x8081)>>23). We need at least 31 bits
      of precision for the product, so mulh_epu16 is perfect. */
-  return ws_and( ws_bcast( 0xFF ), ws_add( x, ws_shru( ws_mulhi( x, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  return ws_and( ws_bcast( 0xFF ), ws_add( x, ws_shru( wh_mulhi( x, ws_bcast( (short)0x8081 ) ), 7 ) ) );
 }
 
 /* The following macros implement the unscaled Fast Walsh-Hadamard
@@ -468,7 +465,7 @@ fd_reedsol_private_gen_pi_16( uchar const * is_erased,
      (x%255) == 0xFF & ( x + (x*0x8081)>>23). We need at least 31 bits
      of precision for the product, so mulh_epu16 is perfect. */
-  log_pi = ws_and( ws_bcast( 0xFF ), ws_add( log_pi, ws_shru( ws_mulhi( log_pi, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  log_pi = ws_and( ws_bcast( 0xFF ), ws_add( log_pi, ws_shru( wh_mulhi( log_pi, ws_bcast( (short)0x8081 ) ), 7 ) ) );
   /* Now 0<=log_pi < 255 */
 
   /* Since our FWHT implementation is unscaled, we've computed a value
@@ -494,7 +491,7 @@ fd_reedsol_private_gen_pi_16( uchar const * is_erased,
   ws_t product = ws_mullo( transformed, ws_ld( fwht_l_twiddle_16 ) );
 
   /* Compute mod 255, using the same approach as above. */
-  product = ws_and( ws_bcast( 0xFF ), ws_add( product, ws_shru( ws_mulhi( product, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product = ws_and( ws_bcast( 0xFF ), ws_add( product, ws_shru( wh_mulhi( product, ws_bcast( (short)0x8081 ) ), 7 ) ) );
 
   wb_t compact_product = compact_ws( product, ws_zero() );
   FD_REEDSOL_FWHT_16( compact_product );
@@ -548,8 +545,8 @@ fd_reedsol_private_gen_pi_32( uchar const * is_erased,
      (x%255) == 0xFF & ( x + (x*0x8081)>>23). We need at least 31 bits
      of precision for the product, so mulh_epu16 is perfect. */
-  log_pi0 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi0, ws_shru( ws_mulhi( log_pi0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  log_pi1 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi1, ws_shru( ws_mulhi( log_pi1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  log_pi0 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi0, ws_shru( wh_mulhi( log_pi0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  log_pi1 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi1, ws_shru( wh_mulhi( log_pi1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
   /* Now 0<=log_pi < 255 */
 
   /* Since our FWHT implementation is unscaled, we've computed a value
@@ -578,8 +575,8 @@ fd_reedsol_private_gen_pi_32( uchar const * is_erased,
   ws_t product1 = ws_mullo( transformed1, ws_ld( fwht_l_twiddle_32 + 16UL ) );
 
   /* Compute mod 255, using the same approach as above. */
-  product0 = ws_and( ws_bcast( 0xFF ), ws_add( product0, ws_shru( ws_mulhi( product0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product1 = ws_and( ws_bcast( 0xFF ), ws_add( product1, ws_shru( ws_mulhi( product1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product0 = ws_and( ws_bcast( 0xFF ), ws_add( product0, ws_shru( wh_mulhi( product0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product1 = ws_and( ws_bcast( 0xFF ), ws_add( product1, ws_shru( wh_mulhi( product1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
 
   wb_t compact_product = compact_ws( product0, product1 );
   FD_REEDSOL_FWHT_32( compact_product );
@@ -646,10 +643,10 @@ fd_reedsol_private_gen_pi_64( uchar const * is_erased,
      (x%255) == 0xFF & ( x + (x*0x8081)>>23). We need at least 31 bits
      of precision for the product, so mulh_epu16 is perfect. */
-  log_pi0 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi0, ws_shru( ws_mulhi( log_pi0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  log_pi1 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi1, ws_shru( ws_mulhi( log_pi1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  log_pi2 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi2, ws_shru( ws_mulhi( log_pi2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  log_pi3 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi3, ws_shru( ws_mulhi( log_pi3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  log_pi0 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi0, ws_shru( wh_mulhi( log_pi0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  log_pi1 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi1, ws_shru( wh_mulhi( log_pi1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  log_pi2 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi2, ws_shru( wh_mulhi( log_pi2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  log_pi3 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi3, ws_shru( wh_mulhi( log_pi3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
   /* Now 0<=log_pi < 255 */
 
   /* Since our FWHT implementation is unscaled, we've computed a value
@@ -689,10 +686,10 @@ fd_reedsol_private_gen_pi_64( uchar const * is_erased,
   ws_t product3 = ws_mullo( transformed3, ws_ld( fwht_l_twiddle_64 + 48UL ) );
 
   /* Compute mod 255, using the same approach as above. */
-  product0 = ws_and( ws_bcast( 0xFF ), ws_add( product0, ws_shru( ws_mulhi( product0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product1 = ws_and( ws_bcast( 0xFF ), ws_add( product1, ws_shru( ws_mulhi( product1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product2 = ws_and( ws_bcast( 0xFF ), ws_add( product2, ws_shru( ws_mulhi( product2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product3 = ws_and( ws_bcast( 0xFF ), ws_add( product3, ws_shru( ws_mulhi( product3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product0 = ws_and( ws_bcast( 0xFF ), ws_add( product0, ws_shru( wh_mulhi( product0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product1 = ws_and( ws_bcast( 0xFF ), ws_add( product1, ws_shru( wh_mulhi( product1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product2 = ws_and( ws_bcast( 0xFF ), ws_add( product2, ws_shru( wh_mulhi( product2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product3 = ws_and( ws_bcast( 0xFF ), ws_add( product3, ws_shru( wh_mulhi( product3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
 
   wb_t compact_product0 = compact_ws( product0, product1 );
   wb_t compact_product1 = compact_ws( product2, product3 );
@@ -763,14 +760,14 @@ fd_reedsol_private_gen_pi_128( uchar const * is_erased,
   product6 = ws_add( product6, ws_bcast( (short)64*255 ) );
   product7 = ws_add( product7, ws_bcast( (short)64*255 ) );
 
-  product0 = ws_and( ws_bcast( 0xFF ), ws_add( product0, ws_shru( ws_mulhi( product0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product1 = ws_and( ws_bcast( 0xFF ), ws_add( product1, ws_shru( ws_mulhi( product1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product2 = ws_and( ws_bcast( 0xFF ), ws_add( product2, ws_shru( ws_mulhi( product2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product3 = ws_and( ws_bcast( 0xFF ), ws_add( product3, ws_shru( ws_mulhi( product3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product4 = ws_and( ws_bcast( 0xFF ), ws_add( product4, ws_shru( ws_mulhi( product4, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product5 = ws_and( ws_bcast( 0xFF ), ws_add( product5, ws_shru( ws_mulhi( product5, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product6 = ws_and( ws_bcast( 0xFF ), ws_add( product6, ws_shru( ws_mulhi( product6, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product7 = ws_and( ws_bcast( 0xFF ), ws_add( product7, ws_shru( ws_mulhi( product7, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product0 = ws_and( ws_bcast( 0xFF ), ws_add( product0, ws_shru( wh_mulhi( product0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product1 = ws_and( ws_bcast( 0xFF ), ws_add( product1, ws_shru( wh_mulhi( product1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product2 = ws_and( ws_bcast( 0xFF ), ws_add( product2, ws_shru( wh_mulhi( product2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product3 = ws_and( ws_bcast( 0xFF ), ws_add( product3, ws_shru( wh_mulhi( product3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product4 = ws_and( ws_bcast( 0xFF ), ws_add( product4, ws_shru( wh_mulhi( product4, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product5 = ws_and( ws_bcast( 0xFF ), ws_add( product5, ws_shru( wh_mulhi( product5, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product6 = ws_and( ws_bcast( 0xFF ), ws_add( product6, ws_shru( wh_mulhi( product6, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product7 = ws_and( ws_bcast( 0xFF ), ws_add( product7, ws_shru( wh_mulhi( product7, ws_bcast( (short)0x8081 ) ), 7 ) ) );
 
   /* Now 0 <= product < 255 */
 
@@ -810,14 +807,14 @@ fd_reedsol_private_gen_pi_128( uchar const * is_erased,
      (x%255) == 0xFF & ( x + (x*0x8081)>>23). We need at least 31 bits
      of precision for the product, so mulh_epu16 is perfect. */
-  log_pi0 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi0, ws_shru( ws_mulhi( log_pi0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  log_pi1 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi1, ws_shru( ws_mulhi( log_pi1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  log_pi2 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi2, ws_shru( ws_mulhi( log_pi2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  log_pi3 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi3, ws_shru( ws_mulhi( log_pi3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  log_pi4 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi4, ws_shru( ws_mulhi( log_pi4, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  log_pi5 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi5, ws_shru( ws_mulhi( log_pi5, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  log_pi6 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi6, ws_shru( ws_mulhi( log_pi6, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  log_pi7 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi7, ws_shru( ws_mulhi( log_pi7, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  log_pi0 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi0, ws_shru( wh_mulhi( log_pi0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  log_pi1 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi1, ws_shru( wh_mulhi( log_pi1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  log_pi2 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi2, ws_shru( wh_mulhi( log_pi2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  log_pi3 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi3, ws_shru( wh_mulhi( log_pi3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  log_pi4 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi4, ws_shru( wh_mulhi( log_pi4, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  log_pi5 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi5, ws_shru( wh_mulhi( log_pi5, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  log_pi6 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi6, ws_shru( wh_mulhi( log_pi6, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  log_pi7 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi7, ws_shru( wh_mulhi( log_pi7, ws_bcast( (short)0x8081 ) ), 7 ) ) );
   /* Now 0<=log_pi < 255 */
 
   /* Since our FWHT implementation is unscaled, we've computed a value
@@ -875,14 +872,14 @@ fd_reedsol_private_gen_pi_128( uchar const * is_erased,
   ws_t product7 = ws_mullo( transformed7, ws_ld( fwht_l_twiddle_128 + 112UL ) );
 
   /* Compute mod 255, using the same approach as above. */
-  product0 = ws_and( ws_bcast( 0xFF ), ws_add( product0, ws_shru( ws_mulhi( product0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product1 = ws_and( ws_bcast( 0xFF ), ws_add( product1, ws_shru( ws_mulhi( product1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product2 = ws_and( ws_bcast( 0xFF ), ws_add( product2, ws_shru( ws_mulhi( product2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product3 = ws_and( ws_bcast( 0xFF ), ws_add( product3, ws_shru( ws_mulhi( product3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product4 = ws_and( ws_bcast( 0xFF ), ws_add( product4, ws_shru( ws_mulhi( product4, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product5 = ws_and( ws_bcast( 0xFF ), ws_add( product5, ws_shru( ws_mulhi( product5, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product6 = ws_and( ws_bcast( 0xFF ), ws_add( product6, ws_shru( ws_mulhi( product6, ws_bcast( (short)0x8081 ) ), 7 ) ) );
-  product7 = ws_and( ws_bcast( 0xFF ), ws_add( product7, ws_shru( ws_mulhi( product7, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product0 = ws_and( ws_bcast( 0xFF ), ws_add( product0, ws_shru( wh_mulhi( product0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product1 = ws_and( ws_bcast( 0xFF ), ws_add( product1, ws_shru( wh_mulhi( product1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product2 = ws_and( ws_bcast( 0xFF ), ws_add( product2, ws_shru( wh_mulhi( product2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product3 = ws_and( ws_bcast( 0xFF ), ws_add( product3, ws_shru( wh_mulhi( product3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product4 = ws_and( ws_bcast( 0xFF ), ws_add( product4, ws_shru( wh_mulhi( product4, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product5 = ws_and( ws_bcast( 0xFF ), ws_add( product5, ws_shru( wh_mulhi( product5, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product6 = ws_and( ws_bcast( 0xFF ), ws_add( product6, ws_shru( wh_mulhi( product6, ws_bcast( (short)0x8081 ) ), 7 ) ) );
+  product7 = ws_and( ws_bcast( 0xFF ), ws_add( product7, ws_shru( wh_mulhi( product7, ws_bcast( (short)0x8081 ) ), 7 ) ) );
   wb_t compact_product0 = compact_ws( product0, product1 );
   wb_t compact_product1 = compact_ws( product2, product3 );
   wb_t compact_product2 = compact_ws( product4, product5 );
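The reedsol changes above lean on the identity documented in the comments: for any 16-bit x, x % 255 == 0xFF & ( x + ((x*0x8081)>>23) ), where mulh_epu16 supplies the high 16 bits of x*0x8081 and the extra shift by 7 completes the >>23. A minimal standalone scalar check of that identity (illustrative only; the helper name mod255_ref is not part of the tree):

#include <stdio.h>

static unsigned short
mod255_ref( unsigned short x ) {
  unsigned int p = (unsigned int)x * 0x8081U;          /* needs full 32-bit precision */
  return (unsigned short)( 0xFFU & ( x + (p>>23) ) );  /* == 0xFF & ( x + (mulhi(x,0x8081)>>7) ) */
}

int
main( void ) {
  for( unsigned int x=0U; x<65536U; x++ )
    if( mod255_ref( (unsigned short)x ) != x%255U ) { printf( "FAIL at %u\n", x ); return 1; }
  printf( "x%%255 identity holds for all 16-bit x\n" );
  return 0;
}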
diff --git a/src/util/simd/fd_avx_wh.h b/src/util/simd/fd_avx_wh.h
index d3c11d0c83..29be84f60a 100644
--- a/src/util/simd/fd_avx_wh.h
+++ b/src/util/simd/fd_avx_wh.h
@@ -89,6 +89,38 @@ wh_insert_variable( wh_t a, int n, ushort v ) {
 #define wh_mulhi(a,b) _mm256_mulhi_epu16( (a), (b) ) /* [ (a0*b0)>>16 (a1*b1)>>16 ... (a15*b15)>>16 ] */
 #define wh_mul(a,b) wh_mullo((a),(b))
 
+/* Binary operations */
+
+/* Note: wh_shl/wh_shr is an unsigned left/right shift by imm bits; imm
+   must be a compile time constant in [0,15]. The variable variants are
+   slower but do not require the shift amount to be known at compile
+   time (should still be in [0,15]). */
+
+#define wh_not(a) _mm256_xor_si256( _mm256_set1_epi16( -1 ), (a) ) /* [ ~a0 ~a1 ... ~a15 ] */
+
+#define wh_shl(a,imm) _mm256_slli_epi16( (a), (imm) ) /* [ a0<<imm a1<<imm ... a15<<imm ] */
+#define wh_shr(a,imm) _mm256_srli_epi16( (a), (imm) ) /* [ a0>>imm a1>>imm ... a15>>imm ] */
+
+#define wh_shl_variable(a,n) _mm256_sll_epi16( (a), _mm_insert_epi64( _mm_setzero_si128(), (n), 0 ) )
+#define wh_shr_variable(a,n) _mm256_srl_epi16( (a), _mm_insert_epi64( _mm_setzero_si128(), (n), 0 ) )
+
+#define wh_shl_vector(a,b) _mm256_sllv_epi16( (a), (b) ) /* [ a0<<b0 a1<<b1 ... a15<<b15 ] */
+#define wh_shr_vector(a,b) _mm256_srlv_epi16( (a), (b) ) /* [ a0>>b0 a1>>b1 ... a15>>b15 ] */
+
+#define wh_and(a,b)    _mm256_and_si256( (a), (b) )    /* [ a0 &b0 a1& b1 ... a15& b15 ] */
+#define wh_andnot(a,b) _mm256_andnot_si256( (a), (b) ) /* [ (~a0)&b0 (~a1)&b1 ... (~a15)&b15 ] */
+#define wh_or(a,b)     _mm256_or_si256( (a), (b) )     /* [ a0 |b0 a1 |b1 ... a15 |b15 ] */
+#define wh_xor(a,b)    _mm256_xor_si256( (a), (b) )    /* [ a0 ^b0 a1 ^b1 ... a15 ^b15 ] */
+
+/* wh_rol(x,n) returns wh( rotate_left (x0,n), rotate_left (x1,n), ... )
+   wh_ror(x,n) returns wh( rotate_right(x0,n), rotate_right(x1,n), ... ) */
+
+static inline wh_t wh_rol( wh_t a, int imm ) { return wh_or( wh_shl( a, imm & 15 ), wh_shr( a, (-imm) & 15 ) ); }
+static inline wh_t wh_ror( wh_t a, int imm ) { return wh_or( wh_shr( a, imm & 15 ), wh_shl( a, (-imm) & 15 ) ); }
+
+static inline wh_t wh_rol_variable( wh_t a, int n ) { return wh_or( wh_shl_variable( a, n&15 ), wh_shr_variable( a, (-n)&15 ) ); }
+static inline wh_t wh_ror_variable( wh_t a, int n ) { return wh_or( wh_shr_variable( a, n&15 ), wh_shl_variable( a, (-n)&15 ) ); }
+
 /* Logical operations */
 
 #define wh_eq(a,b) _mm256_cmpeq_epi16( (a), (b) ) /* [ a0==b0 a1==b1 ... a15==b15 ] */
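The wh_rol/wh_ror helpers above synthesize a 16-bit lane rotate from two shifts and an OR, masking both shift counts into the [0,15] range the shift macros require (a rotate count of 0 then reduces to two shifts by 0). A scalar model of the same construction, with illustrative function names:

#include <stdio.h>

static unsigned short
rol16_ref( unsigned short x, int n ) {  /* mirrors wh_rol: shl by n&15, shr by (-n)&15, or */
  return (unsigned short)( ( x << (n & 15) ) | ( x >> ((-n) & 15) ) );
}

static unsigned short
ror16_ref( unsigned short x, int n ) {  /* mirrors wh_ror: shr by n&15, shl by (-n)&15, or */
  return (unsigned short)( ( x >> (n & 15) ) | ( x << ((-n) & 15) ) );
}

int
main( void ) {
  unsigned short x = 0x8001;
  for( int n=0; n<16; n++ )
    printf( "rol16(0x%04x,%2d)=0x%04x ror16(0x%04x,%2d)=0x%04x\n",
            (unsigned)x, n, (unsigned)rol16_ref( x, n ), (unsigned)x, n, (unsigned)ror16_ref( x, n ) );
  return 0;
}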
diff --git a/src/util/simd/fd_avx_ws.h b/src/util/simd/fd_avx_ws.h
index 2ca00ea8cf..ead7358f73 100644
--- a/src/util/simd/fd_avx_ws.h
+++ b/src/util/simd/fd_avx_ws.h
@@ -89,6 +89,41 @@ ws_insert_variable( ws_t a, int n, short v ) {
 #define ws_mulhi(a,b) _mm256_mulhi_epi16( (a), (b) ) /* [ (a0*b0)>>16 (a1*b1)>>16 ... (a15*b15)>>16 ] */
 #define ws_mul(a,b) ws_mullo((a),(b))
 
+/* Binary operations */
+
+/* Note: ws_shl/ws_shr/ws_shru is a left/signed right/unsigned right
+   shift by imm bits; imm must be a compile time constant in [0,15].
+   The variable variants are slower but do not require the shift amount
+   to be known at compile time (should still be in [0,15]). */
+
+#define ws_not(a) _mm256_xor_si256( _mm256_set1_epi16( -1 ), (a) ) /* [ ~a0 ~a1 ... ~a15 ] */
+
+#define ws_shl(a,imm)  _mm256_slli_epi16( (a), (imm) ) /* [ a0<<imm a1<<imm ... a15<<imm ] */
+#define ws_shr(a,imm)  _mm256_srai_epi16( (a), (imm) ) /* [ a0>>imm a1>>imm ... a15>>imm ] (treat a as signed) */
+#define ws_shru(a,imm) _mm256_srli_epi16( (a), (imm) ) /* [ a0>>imm a1>>imm ... a15>>imm ] (treat a as unsigned) */
+
+#define ws_shl_variable(a,n)  _mm256_sll_epi16( (a), _mm_insert_epi64( _mm_setzero_si128(), (n), 0 ) )
+#define ws_shr_variable(a,n)  _mm256_sra_epi16( (a), _mm_insert_epi64( _mm_setzero_si128(), (n), 0 ) )
+#define ws_shru_variable(a,n) _mm256_srl_epi16( (a), _mm_insert_epi64( _mm_setzero_si128(), (n), 0 ) )
+
+#define ws_shl_vector(a,b)  _mm256_sllv_epi16( (a), (b) ) /* [ a0<<b0 a1<<b1 ... a15<<b15 ] */
+#define ws_shr_vector(a,b)  _mm256_srav_epi16( (a), (b) ) /* [ a0>>b0 a1>>b1 ... a15>>b15 ] (treat a as signed) */
+#define ws_shru_vector(a,b) _mm256_srlv_epi16( (a), (b) ) /* [ a0>>b0 a1>>b1 ... a15>>b15 ] (treat a as unsigned) */
+
+#define ws_and(a,b)    _mm256_and_si256( (a), (b) )    /* [ a0 &b0 a1& b1 ... a15& b15 ] */
+#define ws_andnot(a,b) _mm256_andnot_si256( (a), (b) ) /* [ (~a0)&b0 (~a1)&b1 ... (~a15)&b15 ] */
+#define ws_or(a,b)     _mm256_or_si256( (a), (b) )     /* [ a0 |b0 a1 |b1 ... a15 |b15 ] */
+#define ws_xor(a,b)    _mm256_xor_si256( (a), (b) )    /* [ a0 ^b0 a1 ^b1 ... a15 ^b15 ] */
+
+/* ws_rol(x,n) returns ws( rotate_left (x0,n), rotate_left (x1,n), ... )
+   ws_ror(x,n) returns ws( rotate_right(x0,n), rotate_right(x1,n), ... ) */
+
+static inline ws_t ws_rol( ws_t a, int imm ) { return ws_or( ws_shl( a, imm & 15 ), ws_shru( a, (-imm) & 15 ) ); }
+static inline ws_t ws_ror( ws_t a, int imm ) { return ws_or( ws_shru( a, imm & 15 ), ws_shl( a, (-imm) & 15 ) ); }
+
+static inline ws_t ws_rol_variable( ws_t a, int n ) { return ws_or( ws_shl_variable( a, n&15 ), ws_shru_variable( a, (-n)&15 ) ); }
+static inline ws_t ws_ror_variable( ws_t a, int n ) { return ws_or( ws_shru_variable( a, n&15 ), ws_shl_variable( a, (-n)&15 ) ); }
+
 /* Logical operations */
 
 #define ws_eq(a,b) _mm256_cmpeq_epi16( (a), (b) ) /* [ a0==b0 a1==b1 ... a15==b15 ] */
diff --git a/src/util/simd/test_avx_16x16.c b/src/util/simd/test_avx_16x16.c
index 8b87a87fbc..f76e52691b 100644
--- a/src/util/simd/test_avx_16x16.c
+++ b/src/util/simd/test_avx_16x16.c
@@ -57,6 +57,42 @@ main( int argc,
     /* */ FD_TEST( ws_test( ws_mullo( x, y ), si ) );
     INIT_SI( (short)((((int)xi[j])*((int)yi[j]))>>16) ); FD_TEST( ws_test( ws_mulhi( x, y ), si ) );
 
+    /* Bit operations */
+
+    INIT_SI( (short)~yi[j] ); FD_TEST( ws_test( ws_not( y ), si ) );
+
+# define SHRU(x,n) (short)( (ushort)(x)>>(n) )
+# define ROL(x,n) fd_short_rotate_left ( (x), (n) )
+# define ROR(x,n) fd_short_rotate_right( (x), (n) )
+
+# define _(n) \
+    INIT_SI( (short)(yi[j]<<n) ); FD_TEST( ws_test( ws_shl( y, n ), si ) ); \
+    INIT_SI( (short)(yi[j]>>n) ); FD_TEST( ws_test( ws_shr( y, n ), si ) ); \
+    INIT_SI( SHRU( yi[j], n ) ); FD_TEST( ws_test( ws_shru( y, n ), si ) ); \
+    INIT_SI( ROL( yi[j], n ) ); FD_TEST( ws_test( ws_rol( y, n ), si ) ); \
+    INIT_SI( ROR( yi[j], n ) ); FD_TEST( ws_test( ws_ror( y, n ), si ) )
+    _( 0); _( 1); _( 2); _( 3); _( 4); _( 5); _( 6); _( 7);
+    _( 8); _( 9); _(10); _(11); _(12); _(13); _(14); _(15);
+# undef _
+
+    for( int n=0; n<16; n++ ) {
+      int volatile m[1]; m[0] = n;
+      INIT_SI( (short)(yi[j]<<n ) ); FD_TEST( ws_test( ws_shl_variable( y, m[0] ), si ) );
+      INIT_SI( (short)(yi[j]>>n ) ); FD_TEST( ws_test( ws_shr_variable( y, m[0] ), si ) );
+      INIT_SI( SHRU( yi[j], n ) ); FD_TEST( ws_test( ws_shru_variable( y, m[0] ), si ) );
+      INIT_SI( ROL( yi[j], n ) ); FD_TEST( ws_test( ws_rol_variable( y, m[0] ), si ) );
+      INIT_SI( ROR( yi[j], n ) ); FD_TEST( ws_test( ws_ror_variable( y, m[0] ), si ) );
+    }
+
+# undef SHRU
+# undef ROR
+# undef ROL
+
+    INIT_SI( xi[j] & yi[j] ); FD_TEST( ws_test( ws_and( x, y ), si ) );
+    INIT_SI( ((short)~xi[j]) & yi[j] ); FD_TEST( ws_test( ws_andnot( x, y ), si ) );
+    INIT_SI( xi[j] | yi[j] ); FD_TEST( ws_test( ws_or( x, y ), si ) );
+    INIT_SI( xi[j] ^ yi[j] ); FD_TEST( ws_test( ws_xor( x, y ), si ) );
+
     /* Logical operations */
 
     /* TODO: eliminate this hack (see note in fd_avx_wc.h about
@@ -109,6 +145,38 @@ main( int argc,
     /* */ FD_TEST( wh_test( wh_mullo( x, y ), hj ) );
     INIT_HJ( (ushort)((((uint)xi[j])*((uint)yi[j]))>>16) ); FD_TEST( wh_test( wh_mulhi( x, y ), hj ) );
 
+    /* Bit operations */
+
+    INIT_HJ( (ushort)~yi[j] ); FD_TEST( wh_test( wh_not( y ), hj ) );
+
+# define ROL(x,n) fd_ushort_rotate_left ( (x), (n) )
+# define ROR(x,n) fd_ushort_rotate_right( (x), (n) )
+
+# define _(n) \
+    INIT_HJ( (ushort)(yi[j]<<n) ); FD_TEST( wh_test( wh_shl( y, n ), hj ) ); \
+    INIT_HJ( (ushort)(yi[j]>>n) ); FD_TEST( wh_test( wh_shr( y, n ), hj ) ); \
+    INIT_HJ( ROL( yi[j], n ) ); FD_TEST( wh_test( wh_rol( y, n ), hj ) ); \
+    INIT_HJ( ROR( yi[j], n ) ); FD_TEST( wh_test( wh_ror( y, n ), hj ) )
+    _( 0); _( 1); _( 2); _( 3); _( 4); _( 5); _( 6); _( 7);
+    _( 8); _( 9); _(10); _(11); _(12); _(13); _(14); _(15);
+# undef _
+
+    for( int n=0; n<16; n++ ) {
+      int volatile m[1]; m[0] = n;
+      INIT_HJ( (ushort)(yi[j]<<n ) ); FD_TEST( wh_test( wh_shl_variable( y, m[0] ), hj ) );
+      INIT_HJ( (ushort)(yi[j]>>n ) ); FD_TEST( wh_test( wh_shr_variable( y, m[0] ), hj ) );
+      INIT_HJ( ROL( yi[j], n ) ); FD_TEST( wh_test( wh_rol_variable( y, m[0] ), hj ) );
+      INIT_HJ( ROR( yi[j], n ) ); FD_TEST( wh_test( wh_ror_variable( y, m[0] ), hj ) );
+    }
+
+# undef ROR
+# undef ROL
+
+    INIT_HJ( xi[j] & yi[j] ); FD_TEST( wh_test( wh_and( x, y ), hj ) );
+    INIT_HJ( ((ushort)~xi[j]) & yi[j] ); FD_TEST( wh_test( wh_andnot( x, y ), hj ) );
+    INIT_HJ( xi[j] | yi[j] ); FD_TEST( wh_test( wh_or( x, y ), hj ) );
+    INIT_HJ( xi[j] ^ yi[j] ); FD_TEST( wh_test( wh_xor( x, y ), hj ) );
+
     /* Logical operations */
 
     /* TODO: eliminate this hack (see note in fd_avx_wc.h about