Skip to content

Commit

Permalink
Add AVX 16x16 bitwise ops
Browse files Browse the repository at this point in the history
  • Loading branch information
riptl authored and ripatel-fd committed Dec 6, 2024
1 parent 7e6d477 commit 05093dd
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 42 deletions.
81 changes: 39 additions & 42 deletions src/ballet/reedsol/fd_reedsol_pi.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,14 @@
#include "../../util/simd/fd_sse.h"

#define ws_adjust_sign(a,b) _mm256_sign_epi16( (a), (b) ) /* scales elements in a by the sign of the corresponding element of b */
#define ws_shl(a,imm) _mm256_slli_epi16( (a), (imm) )
#define ws_and(a,b) _mm256_and_si256( (a), (b) )
#define ws_shru(a,imm) _mm256_srli_epi16( (a), (imm) )

static inline ws_t
ws_mod255( ws_t x ) {
  /* Reduce each unsigned 16-bit lane of x modulo 255.
     GCC informs me that for a ushort x,
        (x%255) == 0xFF & ( x + (x*0x8081)>>23 ).
     We need at least 31 bits of precision for the product, so the
     unsigned high multiply (mulh_epu16, exposed as wh_mulhi) is
     perfect.  The signed variant (ws_mulhi / mulhi_epi16) must NOT be
     used here: it would produce a wrong high half for lanes >= 0x8000
     because 0x8081 and large inputs are interpreted as negative. */
  return ws_and( ws_bcast( 0xFF ), ws_add( x, ws_shru( wh_mulhi( x, ws_bcast( (short)0x8081 ) ), 7 ) ) );
}

/* The following macros implement the unscaled Fast Walsh-Hadamard
Expand Down Expand Up @@ -468,7 +465,7 @@ fd_reedsol_private_gen_pi_16( uchar const * is_erased,
(x%255) == 0xFF & ( x + (x*0x8081)>>23).
We need at least 31 bits of precision for the product, so
mulh_epu16 is perfect. */
log_pi = ws_and( ws_bcast( 0xFF ), ws_add( log_pi, ws_shru( ws_mulhi( log_pi, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi = ws_and( ws_bcast( 0xFF ), ws_add( log_pi, ws_shru( wh_mulhi( log_pi, ws_bcast( (short)0x8081 ) ), 7 ) ) );
/* Now 0<=log_pi < 255 */

/* Since our FWHT implementation is unscaled, we've computed a value
Expand All @@ -494,7 +491,7 @@ fd_reedsol_private_gen_pi_16( uchar const * is_erased,
ws_t product = ws_mullo( transformed, ws_ld( fwht_l_twiddle_16 ) );

/* Compute mod 255, using the same approach as above. */
product = ws_and( ws_bcast( 0xFF ), ws_add( product, ws_shru( ws_mulhi( product, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product = ws_and( ws_bcast( 0xFF ), ws_add( product, ws_shru( wh_mulhi( product, ws_bcast( (short)0x8081 ) ), 7 ) ) );
wb_t compact_product = compact_ws( product, ws_zero() );

FD_REEDSOL_FWHT_16( compact_product );
Expand Down Expand Up @@ -548,8 +545,8 @@ fd_reedsol_private_gen_pi_32( uchar const * is_erased,
(x%255) == 0xFF & ( x + (x*0x8081)>>23).
We need at least 31 bits of precision for the product, so
mulh_epu16 is perfect. */
log_pi0 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi0, ws_shru( ws_mulhi( log_pi0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi1 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi1, ws_shru( ws_mulhi( log_pi1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi0 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi0, ws_shru( wh_mulhi( log_pi0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi1 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi1, ws_shru( wh_mulhi( log_pi1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
/* Now 0<=log_pi < 255 */

/* Since our FWHT implementation is unscaled, we've computed a value
Expand Down Expand Up @@ -578,8 +575,8 @@ fd_reedsol_private_gen_pi_32( uchar const * is_erased,
ws_t product1 = ws_mullo( transformed1, ws_ld( fwht_l_twiddle_32 + 16UL ) );

/* Compute mod 255, using the same approach as above. */
product0 = ws_and( ws_bcast( 0xFF ), ws_add( product0, ws_shru( ws_mulhi( product0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product1 = ws_and( ws_bcast( 0xFF ), ws_add( product1, ws_shru( ws_mulhi( product1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product0 = ws_and( ws_bcast( 0xFF ), ws_add( product0, ws_shru( wh_mulhi( product0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product1 = ws_and( ws_bcast( 0xFF ), ws_add( product1, ws_shru( wh_mulhi( product1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
wb_t compact_product = compact_ws( product0, product1 );

FD_REEDSOL_FWHT_32( compact_product );
Expand Down Expand Up @@ -646,10 +643,10 @@ fd_reedsol_private_gen_pi_64( uchar const * is_erased,
(x%255) == 0xFF & ( x + (x*0x8081)>>23).
We need at least 31 bits of precision for the product, so
mulh_epu16 is perfect. */
log_pi0 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi0, ws_shru( ws_mulhi( log_pi0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi1 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi1, ws_shru( ws_mulhi( log_pi1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi2 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi2, ws_shru( ws_mulhi( log_pi2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi3 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi3, ws_shru( ws_mulhi( log_pi3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi0 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi0, ws_shru( wh_mulhi( log_pi0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi1 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi1, ws_shru( wh_mulhi( log_pi1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi2 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi2, ws_shru( wh_mulhi( log_pi2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi3 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi3, ws_shru( wh_mulhi( log_pi3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
/* Now 0<=log_pi < 255 */

/* Since our FWHT implementation is unscaled, we've computed a value
Expand Down Expand Up @@ -689,10 +686,10 @@ fd_reedsol_private_gen_pi_64( uchar const * is_erased,
ws_t product3 = ws_mullo( transformed3, ws_ld( fwht_l_twiddle_64 + 48UL ) );

/* Compute mod 255, using the same approach as above. */
product0 = ws_and( ws_bcast( 0xFF ), ws_add( product0, ws_shru( ws_mulhi( product0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product1 = ws_and( ws_bcast( 0xFF ), ws_add( product1, ws_shru( ws_mulhi( product1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product2 = ws_and( ws_bcast( 0xFF ), ws_add( product2, ws_shru( ws_mulhi( product2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product3 = ws_and( ws_bcast( 0xFF ), ws_add( product3, ws_shru( ws_mulhi( product3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product0 = ws_and( ws_bcast( 0xFF ), ws_add( product0, ws_shru( wh_mulhi( product0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product1 = ws_and( ws_bcast( 0xFF ), ws_add( product1, ws_shru( wh_mulhi( product1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product2 = ws_and( ws_bcast( 0xFF ), ws_add( product2, ws_shru( wh_mulhi( product2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product3 = ws_and( ws_bcast( 0xFF ), ws_add( product3, ws_shru( wh_mulhi( product3, ws_bcast( (short)0x8081 ) ), 7 ) ) );

wb_t compact_product0 = compact_ws( product0, product1 );
wb_t compact_product1 = compact_ws( product2, product3 );
Expand Down Expand Up @@ -763,14 +760,14 @@ fd_reedsol_private_gen_pi_128( uchar const * is_erased,
product6 = ws_add( product6, ws_bcast( (short)64*255 ) );
product7 = ws_add( product7, ws_bcast( (short)64*255 ) );

product0 = ws_and( ws_bcast( 0xFF ), ws_add( product0, ws_shru( ws_mulhi( product0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product1 = ws_and( ws_bcast( 0xFF ), ws_add( product1, ws_shru( ws_mulhi( product1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product2 = ws_and( ws_bcast( 0xFF ), ws_add( product2, ws_shru( ws_mulhi( product2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product3 = ws_and( ws_bcast( 0xFF ), ws_add( product3, ws_shru( ws_mulhi( product3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product4 = ws_and( ws_bcast( 0xFF ), ws_add( product4, ws_shru( ws_mulhi( product4, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product5 = ws_and( ws_bcast( 0xFF ), ws_add( product5, ws_shru( ws_mulhi( product5, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product6 = ws_and( ws_bcast( 0xFF ), ws_add( product6, ws_shru( ws_mulhi( product6, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product7 = ws_and( ws_bcast( 0xFF ), ws_add( product7, ws_shru( ws_mulhi( product7, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product0 = ws_and( ws_bcast( 0xFF ), ws_add( product0, ws_shru( wh_mulhi( product0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product1 = ws_and( ws_bcast( 0xFF ), ws_add( product1, ws_shru( wh_mulhi( product1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product2 = ws_and( ws_bcast( 0xFF ), ws_add( product2, ws_shru( wh_mulhi( product2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product3 = ws_and( ws_bcast( 0xFF ), ws_add( product3, ws_shru( wh_mulhi( product3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product4 = ws_and( ws_bcast( 0xFF ), ws_add( product4, ws_shru( wh_mulhi( product4, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product5 = ws_and( ws_bcast( 0xFF ), ws_add( product5, ws_shru( wh_mulhi( product5, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product6 = ws_and( ws_bcast( 0xFF ), ws_add( product6, ws_shru( wh_mulhi( product6, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product7 = ws_and( ws_bcast( 0xFF ), ws_add( product7, ws_shru( wh_mulhi( product7, ws_bcast( (short)0x8081 ) ), 7 ) ) );

/* Now 0 <= product < 255 */

Expand Down Expand Up @@ -810,14 +807,14 @@ fd_reedsol_private_gen_pi_128( uchar const * is_erased,
(x%255) == 0xFF & ( x + (x*0x8081)>>23).
We need at least 31 bits of precision for the product, so
mulh_epu16 is perfect. */
log_pi0 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi0, ws_shru( ws_mulhi( log_pi0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi1 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi1, ws_shru( ws_mulhi( log_pi1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi2 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi2, ws_shru( ws_mulhi( log_pi2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi3 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi3, ws_shru( ws_mulhi( log_pi3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi4 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi4, ws_shru( ws_mulhi( log_pi4, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi5 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi5, ws_shru( ws_mulhi( log_pi5, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi6 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi6, ws_shru( ws_mulhi( log_pi6, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi7 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi7, ws_shru( ws_mulhi( log_pi7, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi0 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi0, ws_shru( wh_mulhi( log_pi0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi1 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi1, ws_shru( wh_mulhi( log_pi1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi2 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi2, ws_shru( wh_mulhi( log_pi2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi3 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi3, ws_shru( wh_mulhi( log_pi3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi4 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi4, ws_shru( wh_mulhi( log_pi4, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi5 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi5, ws_shru( wh_mulhi( log_pi5, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi6 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi6, ws_shru( wh_mulhi( log_pi6, ws_bcast( (short)0x8081 ) ), 7 ) ) );
log_pi7 = ws_and( ws_bcast( 0xFF ), ws_add( log_pi7, ws_shru( wh_mulhi( log_pi7, ws_bcast( (short)0x8081 ) ), 7 ) ) );
/* Now 0<=log_pi < 255 */

/* Since our FWHT implementation is unscaled, we've computed a value
Expand Down Expand Up @@ -875,14 +872,14 @@ fd_reedsol_private_gen_pi_128( uchar const * is_erased,
ws_t product7 = ws_mullo( transformed7, ws_ld( fwht_l_twiddle_128 + 112UL ) );

/* Compute mod 255, using the same approach as above. */
product0 = ws_and( ws_bcast( 0xFF ), ws_add( product0, ws_shru( ws_mulhi( product0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product1 = ws_and( ws_bcast( 0xFF ), ws_add( product1, ws_shru( ws_mulhi( product1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product2 = ws_and( ws_bcast( 0xFF ), ws_add( product2, ws_shru( ws_mulhi( product2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product3 = ws_and( ws_bcast( 0xFF ), ws_add( product3, ws_shru( ws_mulhi( product3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product4 = ws_and( ws_bcast( 0xFF ), ws_add( product4, ws_shru( ws_mulhi( product4, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product5 = ws_and( ws_bcast( 0xFF ), ws_add( product5, ws_shru( ws_mulhi( product5, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product6 = ws_and( ws_bcast( 0xFF ), ws_add( product6, ws_shru( ws_mulhi( product6, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product7 = ws_and( ws_bcast( 0xFF ), ws_add( product7, ws_shru( ws_mulhi( product7, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product0 = ws_and( ws_bcast( 0xFF ), ws_add( product0, ws_shru( wh_mulhi( product0, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product1 = ws_and( ws_bcast( 0xFF ), ws_add( product1, ws_shru( wh_mulhi( product1, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product2 = ws_and( ws_bcast( 0xFF ), ws_add( product2, ws_shru( wh_mulhi( product2, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product3 = ws_and( ws_bcast( 0xFF ), ws_add( product3, ws_shru( wh_mulhi( product3, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product4 = ws_and( ws_bcast( 0xFF ), ws_add( product4, ws_shru( wh_mulhi( product4, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product5 = ws_and( ws_bcast( 0xFF ), ws_add( product5, ws_shru( wh_mulhi( product5, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product6 = ws_and( ws_bcast( 0xFF ), ws_add( product6, ws_shru( wh_mulhi( product6, ws_bcast( (short)0x8081 ) ), 7 ) ) );
product7 = ws_and( ws_bcast( 0xFF ), ws_add( product7, ws_shru( wh_mulhi( product7, ws_bcast( (short)0x8081 ) ), 7 ) ) );
wb_t compact_product0 = compact_ws( product0, product1 );
wb_t compact_product1 = compact_ws( product2, product3 );
wb_t compact_product2 = compact_ws( product4, product5 );
Expand Down
32 changes: 32 additions & 0 deletions src/util/simd/fd_avx_wh.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,38 @@ wh_insert_variable( wh_t a, int n, ushort v ) {
#define wh_mulhi(a,b) _mm256_mulhi_epu16( (a), (b) ) /* [ (a0*b0)>>16 (a1*b1)>>16 ... (a15*b15)>>16 ] */
#define wh_mul(a,b) wh_mullo((a),(b))

/* Binary operations */

/* Note: wh_shl/wh_shr is an unsigned left/right shift by imm bits; imm
   must be a compile time constant in [0,15].  The variable variants are
   slower but do not require the shift amount to be known at compile
   time (should still be in [0,15]). */

#define wh_not(a) _mm256_xor_si256( _mm256_set1_epi16( -1 ), (a) ) /* [ ~a0 ~a1 ... ~a15 ] */

#define wh_shl(a,imm) _mm256_slli_epi16( (a), (imm) ) /* [ a0<<imm a1<<imm ... a15<<imm ] */
#define wh_shr(a,imm) _mm256_srli_epi16( (a), (imm) ) /* [ a0>>imm a1>>imm ... a15>>imm ] */

#define wh_shl_variable(a,n) _mm256_sll_epi16( (a), _mm_insert_epi64( _mm_setzero_si128(), (n), 0 ) )
#define wh_shr_variable(a,n) _mm256_srl_epi16( (a), _mm_insert_epi64( _mm_setzero_si128(), (n), 0 ) )

/* NOTE(review): _mm256_sllv_epi16 / _mm256_srlv_epi16 (vpsllvw/vpsrlvw)
   are AVX-512VL+AVX-512BW instructions, not plain AVX2 -- confirm the
   build targets that ISA before using the _vector variants. */

#define wh_shl_vector(a,b) _mm256_sllv_epi16( (a), (b) ) /* [ a0<<b0 a1<<b1 ... a15<<b15 ] */
#define wh_shr_vector(a,b) _mm256_srlv_epi16( (a), (b) ) /* [ a0>>b0 a1>>b1 ... a15>>b15 ] */

#define wh_and(a,b) _mm256_and_si256( (a), (b) ) /* [ a0 &b0 a1& b1 ... a15& b15 ] */
#define wh_andnot(a,b) _mm256_andnot_si256( (a), (b) ) /* [ (~a0)&b0 (~a1)&b1 ... (~a15)&b15 ] */
#define wh_or(a,b) _mm256_or_si256( (a), (b) ) /* [ a0 |b0 a1 |b1 ... a15 |b15 ] */
#define wh_xor(a,b) _mm256_xor_si256( (a), (b) ) /* [ a0 ^b0 a1 ^b1 ... a15 ^b15 ] */

/* wh_rol(x,n) returns wh( rotate_left (x0,n), rotate_left (x1,n), ... )
   wh_ror(x,n) returns wh( rotate_right(x0,n), rotate_right(x1,n), ... )
   The &15 masking keeps both component shifts in [0,15] (shifting a
   16-bit lane by 16 would be ill-defined); for n==0 both shifts are 0
   and a|a==a, so rotate-by-0 is the identity as expected. */

static inline wh_t wh_rol( wh_t a, int imm ) { return wh_or( wh_shl( a, imm & 15 ), wh_shr( a, (-imm) & 15 ) ); }
static inline wh_t wh_ror( wh_t a, int imm ) { return wh_or( wh_shr( a, imm & 15 ), wh_shl( a, (-imm) & 15 ) ); }

static inline wh_t wh_rol_variable( wh_t a, int n ) { return wh_or( wh_shl_variable( a, n&15 ), wh_shr_variable( a, (-n)&15 ) ); }
static inline wh_t wh_ror_variable( wh_t a, int n ) { return wh_or( wh_shr_variable( a, n&15 ), wh_shl_variable( a, (-n)&15 ) ); }

/* Logical operations */

#define wh_eq(a,b) _mm256_cmpeq_epi16( (a), (b) ) /* [ a0==b0 a1==b1 ... a15==b15 ] */
Expand Down
35 changes: 35 additions & 0 deletions src/util/simd/fd_avx_ws.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,41 @@ ws_insert_variable( ws_t a, int n, short v ) {
#define ws_mulhi(a,b) _mm256_mulhi_epi16( (a), (b) ) /* [ (a0*b0)>>16 (a1*b1)>>16 ... (a15*b15)>>16 ] */
#define ws_mul(a,b) ws_mullo((a),(b))

/* Binary operations */

/* Note: ws_shl/ws_shr/ws_shru is a left/signed right/unsigned right
   shift by imm bits; imm must be a compile time constant in [0,15].
   The variable variants are slower but do not require the shift amount
   to be known at compile time (should still be in [0,15]). */

#define ws_not(a) _mm256_xor_si256( _mm256_set1_epi16( -1 ), (a) ) /* [ ~a0 ~a1 ... ~a15 ] */

#define ws_shl(a,imm) _mm256_slli_epi16( (a), (imm) ) /* [ a0<<imm a1<<imm ... a15<<imm ] */
#define ws_shr(a,imm) _mm256_srai_epi16( (a), (imm) ) /* [ a0>>imm a1>>imm ... a15>>imm ] (treat a as signed) */
#define ws_shru(a,imm) _mm256_srli_epi16( (a), (imm) ) /* [ a0>>imm a1>>imm ... a15>>imm ] (treat a as unsigned) */

#define ws_shl_variable(a,n) _mm256_sll_epi16( (a), _mm_insert_epi64( _mm_setzero_si128(), (n), 0 ) )
#define ws_shr_variable(a,n) _mm256_sra_epi16( (a), _mm_insert_epi64( _mm_setzero_si128(), (n), 0 ) )
#define ws_shru_variable(a,n) _mm256_srl_epi16( (a), _mm_insert_epi64( _mm_setzero_si128(), (n), 0 ) )

/* NOTE(review): _mm256_sllv_epi16 / _mm256_srav_epi16 / _mm256_srlv_epi16
   (vpsllvw/vpsravw/vpsrlvw) are AVX-512VL+AVX-512BW instructions, not
   plain AVX2 -- confirm the build targets that ISA before using the
   _vector variants. */

#define ws_shl_vector(a,b) _mm256_sllv_epi16( (a), (b) ) /* [ a0<<b0 a1<<b1 ... a15<<b15 ] */
#define ws_shr_vector(a,b) _mm256_srav_epi16( (a), (b) ) /* [ a0>>b0 a1>>b1 ... a15>>b15 ] (treat a as signed) */
#define ws_shru_vector(a,b) _mm256_srlv_epi16( (a), (b) ) /* [ a0>>b0 a1>>b1 ... a15>>b15 ] (treat a as unsigned) */

#define ws_and(a,b) _mm256_and_si256( (a), (b) ) /* [ a0 &b0 a1& b1 ... a15& b15 ] */
#define ws_andnot(a,b) _mm256_andnot_si256( (a), (b) ) /* [ (~a0)&b0 (~a1)&b1 ... (~a15)&b15 ] */
#define ws_or(a,b) _mm256_or_si256( (a), (b) ) /* [ a0 |b0 a1 |b1 ... a15 |b15 ] */
#define ws_xor(a,b) _mm256_xor_si256( (a), (b) ) /* [ a0 ^b0 a1 ^b1 ... a15 ^b15 ] */

/* ws_rol(x,n) returns ws( rotate_left (x0,n), rotate_left (x1,n), ... )
   ws_ror(x,n) returns ws( rotate_right(x0,n), rotate_right(x1,n), ... )
   Rotates use the logical (unsigned) right shift so vacated bits come
   from the other end of the lane, not from sign extension.  The &15
   masking keeps both component shifts in [0,15]; for n==0 both shifts
   are 0 and a|a==a, so rotate-by-0 is the identity as expected. */

static inline ws_t ws_rol( ws_t a, int imm ) { return ws_or( ws_shl( a, imm & 15 ), ws_shru( a, (-imm) & 15 ) ); }
static inline ws_t ws_ror( ws_t a, int imm ) { return ws_or( ws_shru( a, imm & 15 ), ws_shl( a, (-imm) & 15 ) ); }

static inline ws_t ws_rol_variable( ws_t a, int n ) { return ws_or( ws_shl_variable( a, n&15 ), ws_shru_variable( a, (-n)&15 ) ); }
static inline ws_t ws_ror_variable( ws_t a, int n ) { return ws_or( ws_shru_variable( a, n&15 ), ws_shl_variable( a, (-n)&15 ) ); }

/* Logical operations */

#define ws_eq(a,b) _mm256_cmpeq_epi16( (a), (b) ) /* [ a0==b0 a1==b1 ... a15==b15 ] */
Expand Down
Loading

0 comments on commit 05093dd

Please sign in to comment.