Skip to content

Commit

Permalink
Add fd_avx_ws.h (16x short)
Browse files Browse the repository at this point in the history
Analogous to fd_avx_wh.h
  • Loading branch information
riptl authored and ripatel-fd committed Dec 6, 2024
1 parent 5687a5a commit 7e6d477
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 12 deletions.
12 changes: 0 additions & 12 deletions src/ballet/reedsol/fd_reedsol_pi.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,22 +68,10 @@

#include "../../util/simd/fd_sse.h"

#define ws_t __m256i
#define ws_add(a,b) _mm256_add_epi16( (a), (b) )
#define ws_sub(a,b) _mm256_sub_epi16( (a), (b) )
#define ws_bcast(s0) _mm256_set1_epi16( (s0) )
#define ws_adjust_sign(a,b) _mm256_sign_epi16( (a), (b) ) /* scales elements in a by the sign of the corresponding element of b */
#define ws_mullo(a,b) _mm256_mullo_epi16( (a), (b) )
#define ws_mulhi(a,b) _mm256_mulhi_epu16( (a), (b) )
#define ws_shl(a,imm) _mm256_slli_epi16( (a), (imm) )
#define ws_and(a,b) _mm256_and_si256( (a), (b) )
#define ws_shru(a,imm) _mm256_srli_epi16( (a), (imm) )
#define ws_zero() _mm256_setzero_si256() /* Return [ 0 0 0 0 0 ... 0 0 ] */

FD_FN_UNUSED static inline ws_t ws_ld( short const * p ) { return _mm256_load_si256( (__m256i const *)p ); }
FD_FN_UNUSED static inline ws_t ws_ldu( short const * p ) { return _mm256_loadu_si256( (__m256i const *)p ); }
FD_FN_UNUSED static inline void ws_st( short * p, ws_t i ) { _mm256_store_si256( (__m256i *)p, i ); }
FD_FN_UNUSED static inline void ws_stu( short * p, ws_t i ) { _mm256_storeu_si256( (__m256i *)p, i ); }

static inline ws_t
ws_mod255( ws_t x ) {
Expand Down
1 change: 1 addition & 0 deletions src/util/simd/fd_avx.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
#include "fd_avx_wl.h" /* Vector long support */
#include "fd_avx_wv.h" /* Vector ulong support */
#include "fd_avx_wb.h" /* Vector uchar (byte) support */
#include "fd_avx_ws.h" /* Vector short support */
#include "fd_avx_wh.h" /* Vector ushort support */

#else
Expand Down
95 changes: 95 additions & 0 deletions src/util/simd/fd_avx_ws.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#ifndef HEADER_fd_src_util_simd_fd_avx_h
#error "Do not include this directly; use fd_avx.h"
#endif

/* Vector short API ***************************************************/

/* A ws_t is a vector where each 16-bit wide lane holds a signed 16-bit
   integer (a "short").

   These mirror the other APIs as much as possible.  Macros are
   preferred over static inlines when it is possible to do it robustly
   to reduce the risk of the compiler mucking it up. */

#define ws_t __m256i

/* Constructors */

/* Given the short values, return the vector short [ h0 h1 ... h15 ]
   (each argument is converted to a short before packing). */

#define ws(h0, h1, h2, h3, h4, h5, h6, h7, h8, h9,h10,h11,h12,h13,h14,h15) /* [ h0 h1 ... h15 ] */ \
  _mm256_setr_epi16( (short)( h0), (short)( h1), (short)( h2), (short)( h3), \
                     (short)( h4), (short)( h5), (short)( h6), (short)( h7), \
                     (short)( h8), (short)( h9), (short)(h10), (short)(h11), \
                     (short)(h12), (short)(h13), (short)(h14), (short)(h15) )

#define ws_bcast(h0) _mm256_set1_epi16( (short)(h0) ) /* [ h0 h0 ... h0 ] */

/* Predefined constants */

#define ws_zero() _mm256_setzero_si256() /* Return [ 0 0 ... 0 ] */
#define ws_one()  _mm256_set1_epi16( 1 ) /* Return [ 1 1 ... 1 ] */

/* Memory operations */

/* ws_ld return the 16 shorts at the 32-byte aligned / 32-byte sized
   location p as a vector short.  ws_ldu is the same but p does not
   have to be aligned.  ws_st writes the vector short to the 32-byte
   aligned / 32-byte sized location p as 16 shorts.  ws_stu is the same
   but p does not have to be aligned.  In all these lane l will be at
   p[l].

   Note: gcc knows a __m256i may alias. */

static inline ws_t ws_ld( short const * p ) { return _mm256_load_si256( (__m256i const *)p ); }
static inline void ws_st( short * p, ws_t i ) { _mm256_store_si256( (__m256i *)p, i ); }

/* NOTE(review): the unaligned variants take void * while the aligned
   variants take short * -- presumably deliberate (unaligned accesses
   typically come from untyped byte buffers), but asymmetric; confirm
   this matches the fd_avx_wh.h convention it mirrors. */
static inline ws_t ws_ldu( void const * p ) { return _mm256_loadu_si256( (__m256i const *)p ); }
static inline void ws_stu( void * p, ws_t i ) { _mm256_storeu_si256( (__m256i *)p, i ); }

/* Element operations */

/* ws_extract extracts the short in lane imm from the vector short.
   ws_insert returns the vector short formed by replacing the value in
   lane imm of a with the provided short.  imm should be a compile time
   constant in 0:15.  ws_extract_variable and ws_insert_variable are
   slower but the lane n does not have to be known at compile time
   (should still be in 0:15).

   Note: C99 TC3 allows type punning through a union. */

#define ws_extract(a,imm) ((short)_mm256_extract_epi16( (a), (imm) ))
#define ws_insert(a,imm,v) _mm256_insert_epi16( (a), (int)(v), (imm) )

/* Extract lane n of a when n is not known at compile time.  Spills the
   vector to a stack buffer and reads lane n back out; C99 TC3 permits
   this type punning through a union. */

static inline short
ws_extract_variable( ws_t a, int n ) {
  union { __m256i m; short lane[16]; } u;
  _mm256_store_si256( &u.m, a );
  return u.lane[n];
}

/* Return a with lane n replaced by v when n is not known at compile
   time.  Round trips through a stack buffer: spill, overwrite lane n,
   reload; C99 TC3 permits this type punning through a union. */

static inline ws_t
ws_insert_variable( ws_t a, int n, short v ) {
  union { __m256i m; short lane[16]; } u;
  _mm256_store_si256( &u.m, a );
  u.lane[n] = v;
  return _mm256_load_si256( &u.m );
}

/* Arithmetic operations */

/* Note: lane comments below run a0..a15 (16 lanes of 16-bit shorts). */

#define ws_neg(a) _mm256_sub_epi16( _mm256_setzero_si256(), (a) ) /* [ -a0 -a1 ... -a15 ] (twos complement handling) */
#define ws_abs(a) _mm256_abs_epi16( (a) ) /* [ |a0| |a1| ... |a15| ] (twos complement handling) */

#define ws_min(a,b) _mm256_min_epi16( (a), (b) ) /* [ min(a0,b0) min(a1,b1) ... min(a15,b15) ] */
#define ws_max(a,b) _mm256_max_epi16( (a), (b) ) /* [ max(a0,b0) max(a1,b1) ... max(a15,b15) ] */
#define ws_add(a,b) _mm256_add_epi16( (a), (b) ) /* [ a0 +b0 a1 +b1 ... a15 +b15 ] */
#define ws_sub(a,b) _mm256_sub_epi16( (a), (b) ) /* [ a0 -b0 a1 -b1 ... a15 -b15 ] */
#define ws_mullo(a,b) _mm256_mullo_epi16( (a), (b) ) /* [ a0*b0 a1*b1 ... a15*b15 ] (lo 16 bits of the products) */
#define ws_mulhi(a,b) _mm256_mulhi_epi16( (a), (b) ) /* [ (a0*b0)>>16 (a1*b1)>>16 ... (a15*b15)>>16 ] (signed hi 16 bits) */
#define ws_mul(a,b) ws_mullo((a),(b)) /* a short*short product keeps the lo 16 bits */

/* Logical operations */

/* In the returned masks, a "true" lane is all bits set (-1) and a
   "false" lane is all bits clear (0), per the cmpeq semantics. */

#define ws_eq(a,b) _mm256_cmpeq_epi16( (a), (b) ) /* [ a0==b0 a1==b1 ... a15==b15 ] */
#define ws_ne(a,b) _mm256_xor_si256( _mm256_set1_epi16( -1 ), _mm256_cmpeq_epi16( (a), (b) ) ) /* [ a0!=b0 a1!=b1 ... a15!=b15 ] */
60 changes: 60 additions & 0 deletions src/util/simd/test_avx_16x16.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ main( int argc,
char ** argv ) {
fd_boot( &argc, &argv );

# define srand() ((short)((fd_rng_uint( rng ) % 7U)-3U)) /* [-3,-2,-1,0,1,2,3] */
# define hrand() ((ushort)((fd_rng_uint( rng ) % 7U)-3U)) /* [65533,65534,65535,0,1,2,3] */

fd_rng_t _rng[1]; fd_rng_t * rng = fd_rng_join( fd_rng_new( _rng, 0U, 0UL ) );
Expand All @@ -24,6 +25,58 @@ main( int argc,

# define INVOKE_EXPAND( M, ... ) M( __VA_ARGS__ )

/* WS tests */

short si[ 16 ];
# define INIT_SI( EXPR ) do { for( ulong j=0UL; j<16UL; j++ ) { si[j] = (EXPR); } } while( 0 )

INIT_SI( (short)0 ); FD_TEST( ws_test( ws_zero(), si ) );
INIT_SI( (short)1 ); FD_TEST( ws_test( ws_one(), si ) );

for( int i=0; i<65536; i++ ) {

/* Constructors */

short xi[ 16 ]; for( ulong j=0UL; j<16UL; j++ ) xi[ j ] = srand();
short yi[ 16 ]; for( ulong j=0UL; j<16UL; j++ ) yi[ j ] = srand();

INIT_SI( yi[ 0 ] ); FD_TEST( ws_test( ws_bcast( yi[0] ), si ) );

ws_t x = INVOKE_EXPAND( ws, EXPAND_16_INDICES( xi, 0 ) ); FD_TEST( ws_test( x, xi ) );
ws_t y = INVOKE_EXPAND( ws, EXPAND_16_INDICES( yi, 0 ) ); FD_TEST( ws_test( y, yi ) );

/* Arithmetic operations */

INIT_SI( (short)-xi[j] ); FD_TEST( ws_test( ws_neg( x ), si ) );
INIT_SI( (short)fd_short_abs( xi[j] ) ); FD_TEST( ws_test( ws_abs( x ), si ) );
INIT_SI( fd_short_min( xi[j], yi[j] ) ); FD_TEST( ws_test( ws_min( x, y ), si ) );
INIT_SI( fd_short_max( xi[j], yi[j] ) ); FD_TEST( ws_test( ws_max( x, y ), si ) );
INIT_SI( (short)(xi[j]+yi[j]) ); FD_TEST( ws_test( ws_add( x, y ), si ) );
INIT_SI( (short)(xi[j]-yi[j]) ); FD_TEST( ws_test( ws_sub( x, y ), si ) );
INIT_SI( (short)(xi[j]*yi[j]) ); FD_TEST( ws_test( ws_mul( x, y ), si ) );
/* */ FD_TEST( ws_test( ws_mullo( x, y ), si ) );
INIT_SI( (short)((((int)xi[j])*((int)yi[j]))>>16) ); FD_TEST( ws_test( ws_mulhi( x, y ), si ) );

/* Logical operations */

/* TODO: eliminate this hack (see note in fd_avx_wc.h about
properly generalizing wc to 8/16/32/64-bit wide SIMD lanes). */

# define wc_to_ws_raw( x ) (x)

# define C(cond) ((short)(-(cond)))

INIT_SI( C(xi[j]==yi[j]) ); FD_TEST( ws_test( wc_to_ws_raw( ws_eq( x, y ) ), si ) );
INIT_SI( C(xi[j]!=yi[j]) ); FD_TEST( ws_test( wc_to_ws_raw( ws_ne( x, y ) ), si ) );

# undef C

# undef wc_to_ws_raw

}

# undef INIT_SI

/* WH tests */

ushort hj[ 16 ];
Expand Down Expand Up @@ -68,9 +121,16 @@ main( int argc,
INIT_HJ( C(xi[j]==yi[j]) ); FD_TEST( wh_test( wc_to_wh_raw( wh_eq( x, y ) ), hj ) );
INIT_HJ( C(xi[j]!=yi[j]) ); FD_TEST( wh_test( wc_to_wh_raw( wh_ne( x, y ) ), hj ) );

# undef C

# undef wc_to_wh_raw

}

# undef INIT_HJ

# undef hrand
# undef srand

FD_LOG_NOTICE(( "pass" ));
fd_halt();
Expand Down
67 changes: 67 additions & 0 deletions src/util/simd/test_avx_common.c
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,73 @@ int wv_test( wv_t v, ulong v0, ulong v1, ulong v2, ulong v3 ) {
return 1;
}

/* ws_test returns 1 if the vector short s holds exactly the 16
   reference shorts si[0:15] and the ws element / memory accessors all
   round trip s correctly, and 0 on the first mismatch. */

int ws_test( ws_t s, short const * si ) {
  int volatile _[1];
  short m[151] W_ATTR;
  ws_t t;

  /* ws_extract requires a compile time lane index, hence unrolled */

  if( ws_extract( s, 0 )!=si[ 0] ) return 0;
  if( ws_extract( s, 1 )!=si[ 1] ) return 0;
  if( ws_extract( s, 2 )!=si[ 2] ) return 0;
  if( ws_extract( s, 3 )!=si[ 3] ) return 0;
  if( ws_extract( s, 4 )!=si[ 4] ) return 0;
  if( ws_extract( s, 5 )!=si[ 5] ) return 0;
  if( ws_extract( s, 6 )!=si[ 6] ) return 0;
  if( ws_extract( s, 7 )!=si[ 7] ) return 0;
  if( ws_extract( s, 8 )!=si[ 8] ) return 0;
  if( ws_extract( s, 9 )!=si[ 9] ) return 0;
  if( ws_extract( s, 10 )!=si[10] ) return 0;
  if( ws_extract( s, 11 )!=si[11] ) return 0;
  if( ws_extract( s, 12 )!=si[12] ) return 0;
  if( ws_extract( s, 13 )!=si[13] ) return 0;
  if( ws_extract( s, 14 )!=si[14] ) return 0;
  if( ws_extract( s, 15 )!=si[15] ) return 0;

  /* _[0] is volatile so the lane index cannot be constant folded */
  for( int j=0; j<16; j++ ) { _[0]=j; if( ws_extract_variable( s, _[0] )!=si[j] ) return 0; }

  /* m is 32-byte aligned (W_ATTR); offsets below are in shorts, so
     m+16 is a 32-byte aligned address hit with the unaligned op */
  ws_st( m, s ); /* Aligned store to aligned */
  ws_stu( m+16, s ); /* Unaligned store to aligned */
  ws_stu( m+33, s ); /* Unaligned store to aligned+1 */
  ws_stu( m+50, s ); /* Unaligned store to aligned+2 */
  ws_stu( m+67, s ); /* Unaligned store to aligned+3 */
  ws_stu( m+84, s ); /* Unaligned store to aligned+4 */
  ws_stu( m+101, s ); /* Unaligned store to aligned+5 */
  ws_stu( m+118, s ); /* Unaligned store to aligned+6 */
  ws_stu( m+135, s ); /* Unaligned store to aligned+7 */

  /* NOTE(review): the m+16 store is never reloaded below -- presumably
     intentional parity with the other *_test helpers; confirm */
  t = ws_ld( m ); if( _mm256_movemask_epi8( ws_eq( s, t ) )!=-1 ) return 0;
  t = ws_ldu( m+33 ); if( _mm256_movemask_epi8( ws_eq( s, t ) )!=-1 ) return 0;
  t = ws_ldu( m+50 ); if( _mm256_movemask_epi8( ws_eq( s, t ) )!=-1 ) return 0;
  t = ws_ldu( m+67 ); if( _mm256_movemask_epi8( ws_eq( s, t ) )!=-1 ) return 0;
  t = ws_ldu( m+84 ); if( _mm256_movemask_epi8( ws_eq( s, t ) )!=-1 ) return 0;
  t = ws_ldu( m+101 ); if( _mm256_movemask_epi8( ws_eq( s, t ) )!=-1 ) return 0;
  t = ws_ldu( m+118 ); if( _mm256_movemask_epi8( ws_eq( s, t ) )!=-1 ) return 0;
  t = ws_ldu( m+135 ); if( _mm256_movemask_epi8( ws_eq( s, t ) )!=-1 ) return 0;

  /* ws_insert requires a compile time lane index, hence unrolled */

  t = ws_insert( ws_zero(), 0, si[ 0] );
  t = ws_insert( t, 1, si[ 1] );
  t = ws_insert( t, 2, si[ 2] );
  t = ws_insert( t, 3, si[ 3] );
  t = ws_insert( t, 4, si[ 4] );
  t = ws_insert( t, 5, si[ 5] );
  t = ws_insert( t, 6, si[ 6] );
  t = ws_insert( t, 7, si[ 7] );
  t = ws_insert( t, 8, si[ 8] );
  t = ws_insert( t, 9, si[ 9] );
  t = ws_insert( t, 10, si[10] );
  t = ws_insert( t, 11, si[11] );
  t = ws_insert( t, 12, si[12] );
  t = ws_insert( t, 13, si[13] );
  t = ws_insert( t, 14, si[14] );
  t = ws_insert( t, 15, si[15] ); if( _mm256_movemask_epi8( ws_ne( s, t ) ) ) return 0;

  t = ws_zero();
  for( int j=0; j<16; j++ ) { _[0]=j; t=ws_insert_variable( t, _[0], si[j] ); }
  if( _mm256_movemask_epi8( ws_ne( s, t ) ) ) return 0;

  return 1;
}

int wh_test( wh_t h, ushort const * hj ) {
int volatile _[1];
ushort m[151] W_ATTR;
Expand Down

0 comments on commit 7e6d477

Please sign in to comment.