Skip to content

Commit

Permalink
Rewrite scal (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamirDroubi authored Oct 6, 2023
1 parent 9c7d49c commit 2cd26cd
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 150 deletions.
40 changes: 40 additions & 0 deletions src/common/machines/avx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,44 @@ def mm256_prefix_add_pd(
out[i] = x[i] + y[i]


@instr(
"""
{{
__m256i indices = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
__m256i prefix = _mm256_set1_epi32({bound});
__m256i cmp = _mm256_cmpgt_epi32(prefix, indices);
{dst_data} = _mm256_blendv_ps ({dst_data}, _mm256_setzero_ps(), _mm256_castsi256_ps(cmp));
}}
"""
)
def mm256_prefix_setzero_ps(dst: [f32][8] @ AVX2, bound: size):
assert stride(dst, 0) == 1
assert bound <= 8

for i in seq(0, 8):
if i < bound:
dst[i] = 0.0


@instr(
"""
{{
__m256i indices = _mm256_set_epi64x(3, 2, 1, 0);
__m256i prefix = _mm256_set1_epi64x({bound});
__m256i cmp = _mm256_cmpgt_epi64(prefix, indices);
{dst_data} = _mm256_blendv_pd ({dst_data}, _mm256_setzero_pd(), _mm256_castsi256_pd(cmp));
}}
"""
)
def mm256_prefix_setzero_pd(dst: [f64][4] @ AVX2, bound: size):
assert stride(dst, 0) == 1
assert bound <= 4

for i in seq(0, 4):
if i < bound:
dst[i] = 0.0


Machine = MachineParameters(
name="avx2",
mem_type=AVX2,
Expand All @@ -552,6 +590,7 @@ def mm256_prefix_add_pd(
fmadd_instr_f32=mm256_fmadd_ps,
prefix_fmadd_instr_f32=mm256_prefix_fmadd_ps,
set_zero_instr_f32=mm256_setzero_ps,
prefix_set_zero_instr_f32=mm256_prefix_setzero_ps,
assoc_reduce_add_instr_f32=avx2_assoc_reduce_add_ps,
assoc_reduce_add_f32_buffer=avx2_assoc_reduce_add_ps_buffer,
mul_instr_f32=mm256_mul_ps,
Expand Down Expand Up @@ -579,6 +618,7 @@ def mm256_prefix_add_pd(
fmadd_instr_f64=mm256_fmadd_pd,
prefix_fmadd_instr_f64=mm256_prefix_fmadd_pd,
set_zero_instr_f64=mm256_setzero_pd,
prefix_set_zero_instr_f64=mm256_prefix_setzero_pd,
assoc_reduce_add_instr_f64=avx2_assoc_reduce_add_pd,
mul_instr_f64=mm256_mul_pd,
prefix_mul_instr_f64=mm256_prefix_mul_pd,
Expand Down
2 changes: 2 additions & 0 deletions src/common/machines/machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class MachineParameters:
fmadd_instr_f32: Any
prefix_fmadd_instr_f32: Any
set_zero_instr_f32: Any
prefix_set_zero_instr_f32: Any
assoc_reduce_add_instr_f32: Any
mul_instr_f32: Any
prefix_mul_instr_f32: Any
Expand Down Expand Up @@ -65,6 +66,7 @@ class MachineParameters:
fmadd_instr_f64: Any
prefix_fmadd_instr_f64: Any
set_zero_instr_f64: Any
prefix_set_zero_instr_f64: Any
assoc_reduce_add_instr_f64: Any
mul_instr_f64: Any
prefix_mul_instr_f64: Any
Expand Down
2 changes: 2 additions & 0 deletions src/common/machines/neon.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def neon_vst_2xf64_backwards(dst: [f64][2] @ DRAM, src: [f64][2] @ Neon):
fmadd_instr_f32=neon_vfmadd_4xf32_4xf32,
prefix_fmadd_instr_f32=None,
set_zero_instr_f32=neon_zero_4xf32,
prefix_set_zero_instr_f32=None,
assoc_reduce_add_instr_f32=neon_assoc_reduce_add_instr_4xf32,
mul_instr_f32=neon_vmul_4xf32,
prefix_mul_instr_f32=None,
Expand Down Expand Up @@ -128,6 +129,7 @@ def neon_vst_2xf64_backwards(dst: [f64][2] @ DRAM, src: [f64][2] @ Neon):
fmadd_instr_f64=neon_vfmadd_2xf64_2xf64,
prefix_fmadd_instr_f64=None,
set_zero_instr_f64=neon_zero_2xf64,
prefix_set_zero_instr_f64=None,
assoc_reduce_add_instr_f64=neon_assoc_reduce_add_instr_2xf64,
assoc_reduce_add_f64_buffer=neon_assoc_reduce_add_instr_2xf64_buffer,
mul_instr_f64=neon_vmul_2xf64,
Expand Down
193 changes: 43 additions & 150 deletions src/level1/scal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@

import exo_blas_config as C
from composed_schedules import (
vectorize_to_loops,
interleave_execution,
hoist_stmt,
apply_to_block,
hoist_stmt,
)

from blas_composed_schedules import blas_vectorize
from codegen_helpers import (
generate_stride_any_proc,
export_exo_proc,
generate_stride_1_proc,
)
from parameters import Level_1_Params

### EXO_LOC ALGORITHM START ###
@proc
Expand All @@ -33,156 +37,45 @@ def scal_template_alpha_0(n: size, x: [R][n]):


### EXO_LOC SCHEDULE START ###
def specialize_scal(precision, alpha):
prefix = "s" if precision == "f32" else "d"
specialized_scal = scal_template if alpha != 0 else scal_template_alpha_0
specialized_scal_name = specialized_scal.name()
specialized_scal_name = specialized_scal_name.replace("_template", "")
specialized_scal = rename(specialized_scal, "exo_" + prefix + specialized_scal_name)

args = ["x"]
if alpha != 0:
args.append("alpha")

for arg in args:
specialized_scal = set_precision(specialized_scal, arg, precision)
return specialized_scal


def schedule_scal_stride_1(
VEC_W, INTERLEAVE_FACTOR, memory, instructions, precision, alpha
):
simple_stride_1 = specialize_scal(precision, alpha)
simple_stride_1 = rename(simple_stride_1, simple_stride_1.name() + "_stride_1")
simple_stride_1 = simple_stride_1.add_assertion("stride(x, 0) == 1")

main_loop = simple_stride_1.find_loop("i")
simple_stride_1 = vectorize_to_loops(
simple_stride_1, main_loop, VEC_W, memory, precision
def schedule_scal_stride_1(scal, params):
scal = generate_stride_1_proc(scal, params.precision)
main_loop = scal.find_loop("i")
scal = blas_vectorize(scal, main_loop, params)
main_loop = scal.find_loop("ioo")
scal = add_unsafe_guard(
scal,
main_loop.as_block(),
FormattedExprStr("_ < _", main_loop.lo(), main_loop.hi()),
)
simple_stride_1 = interleave_execution(
simple_stride_1, simple_stride_1.find_loop("io"), INTERLEAVE_FACTOR
main_loop = scal.find_loop("ioo")
scal = apply_to_block(scal, main_loop.body(), hoist_stmt)
middle_loop = scal.find_loop("ioi")
scal = add_unsafe_guard(
scal,
middle_loop.as_block(),
FormattedExprStr("_ < _", middle_loop.lo(), middle_loop.hi()),
)
simple_stride_1 = apply_to_block(
simple_stride_1, simple_stride_1.find_loop("ioo").body(), hoist_stmt
)
simple_stride_1 = replace_all(simple_stride_1, instructions)
return simple_stride_1


#################################################
# Generate specialized kernels for f32 precision
#################################################
middle_loop = scal.find_loop("ioi")
scal = apply_to_block(scal, middle_loop.body(), hoist_stmt)
return simplify(scal)

INTERLEAVE_FACTOR = C.Machine.vec_units * 2

exo_sscal_stride_any = specialize_scal("f32", None)
exo_sscal_stride_any = rename(
exo_sscal_stride_any, exo_sscal_stride_any.name() + "_stride_any"
)
exo_sscal_alpha_0_stride_any = specialize_scal("f32", 0)
exo_sscal_alpha_0_stride_any = rename(
exo_sscal_alpha_0_stride_any, exo_sscal_alpha_0_stride_any.name() + "_stride_any"
)

f32_instructions = [
C.Machine.load_instr_f32,
C.Machine.store_instr_f32,
C.Machine.mul_instr_f32,
C.Machine.broadcast_scalar_instr_f32,
C.Machine.set_zero_instr_f32,
template_sched_list = [
(scal_template, schedule_scal_stride_1),
(scal_template_alpha_0, schedule_scal_stride_1),
]

if None not in f32_instructions:
exo_sscal_stride_1 = schedule_scal_stride_1(
C.Machine.vec_width,
INTERLEAVE_FACTOR,
C.Machine.mem_type,
f32_instructions,
"f32",
None,
)
exo_sscal_alpha_0_stride_1 = schedule_scal_stride_1(
C.Machine.vec_width,
INTERLEAVE_FACTOR,
C.Machine.mem_type,
f32_instructions,
"f32",
0,
)
else:
exo_sscal_stride_1 = specialize_scal("f32", None)
exo_sscal_stride_1 = rename(
exo_sscal_stride_1, exo_sscal_stride_1.name() + "_stride_1"
)
exo_sscal_alpha_0_stride_1 = specialize_scal("f32", 0)
exo_sscal_alpha_0_stride_1 = rename(
exo_sscal_alpha_0_stride_1, exo_sscal_alpha_0_stride_1.name() + "_stride_1"
)

#################################################
# Generate specialized kernels for f64 precision
#################################################

exo_dscal_stride_any = specialize_scal("f64", None)
exo_dscal_stride_any = rename(
exo_dscal_stride_any, exo_dscal_stride_any.name() + "_stride_any"
)
exo_dscal_alpha_0_stride_any = specialize_scal("f64", 0)
exo_dscal_alpha_0_stride_any = rename(
exo_dscal_alpha_0_stride_any, exo_dscal_alpha_0_stride_any.name() + "_stride_any"
)

f64_instructions = [
C.Machine.load_instr_f64,
C.Machine.store_instr_f64,
C.Machine.mul_instr_f64,
C.Machine.broadcast_scalar_instr_f64,
C.Machine.set_zero_instr_f64,
]

if None not in f64_instructions:
exo_dscal_stride_1 = schedule_scal_stride_1(
C.Machine.vec_width // 2,
INTERLEAVE_FACTOR,
C.Machine.mem_type,
f64_instructions,
"f64",
None,
)
exo_dscal_alpha_0_stride_1 = schedule_scal_stride_1(
C.Machine.vec_width // 2,
INTERLEAVE_FACTOR,
C.Machine.mem_type,
f64_instructions,
"f64",
0,
)
else:
exo_dscal_stride_1 = specialize_scal("f64", None)
exo_dscal_stride_1 = rename(
exo_dscal_stride_1, exo_dscal_stride_1.name() + "_stride_1"
)
exo_dscal_alpha_0_stride_1 = specialize_scal("f64", 0)
exo_dscal_alpha_0_stride_1 = rename(
exo_dscal_alpha_0_stride_1, exo_dscal_alpha_0_stride_1.name() + "_stride_1"
)
# TODO: Debug alpha zero case

for precision in ("f32", "f64"):
for template, sched in template_sched_list:
proc_stride_any = generate_stride_any_proc(template, precision)
export_exo_proc(globals(), proc_stride_any)
proc_stride_1 = sched(
template,
Level_1_Params(
precision=precision, accumulators_count=1, interleave_factor=4
),
)
export_exo_proc(globals(), proc_stride_1)
### EXO_LOC SCHEDULE END ###

entry_points = [
exo_sscal_stride_any,
exo_sscal_stride_1,
exo_sscal_alpha_0_stride_1,
exo_sscal_alpha_0_stride_any,
exo_dscal_stride_any,
exo_dscal_stride_1,
exo_dscal_alpha_0_stride_1,
exo_dscal_alpha_0_stride_any,
]


if __name__ == "__main__":
for p in entry_points:
print(p)

__all__ = [p.name() for p in entry_points]

0 comments on commit 2cd26cd

Please sign in to comment.