Skip to content

Commit

Permalink
Rewrite swap (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamirDroubi authored Oct 6, 2023
1 parent 2cd26cd commit f052c68
Showing 1 changed file with 25 additions and 65 deletions.
90 changes: 25 additions & 65 deletions src/level1/swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@
from exo.stdlib.scheduling import *

import exo_blas_config as C
from composed_schedules import vectorize_to_loops, interleave_execution

from blas_composed_schedules import blas_vectorize
from codegen_helpers import (
generate_stride_any_proc,
export_exo_proc,
generate_stride_1_proc,
)
from parameters import Level_1_Params

### EXO_LOC ALGORITHM START ###
@proc
Expand All @@ -24,72 +29,27 @@ def swap_template(n: size, x: [R][n], y: [R][n]):


### EXO_LOC SCHEDULE START ###
def specialize_precision(precision):
prefix = "s" if precision == "f32" else "d"
specialized_copy = rename(swap_template, "exo_" + prefix + "swap")
for arg in ["x", "y", "tmp"]:
specialized_copy = set_precision(specialized_copy, arg, precision)
return specialized_copy


def schedule_swap_stride_1(VEC_W, memory, instructions, precision):
simple_stride_1 = specialize_precision(precision)
simple_stride_1 = rename(simple_stride_1, simple_stride_1.name() + "_stride_1")
simple_stride_1 = simple_stride_1.add_assertion("stride(x, 0) == 1")
simple_stride_1 = simple_stride_1.add_assertion("stride(y, 0) == 1")
def schedule_swap(swap, params):
swap = generate_stride_1_proc(swap, params.precision)
main_loop = swap.find_loop("i")
swap = blas_vectorize(swap, main_loop, params)
return simplify(swap)

simple_stride_1 = vectorize_to_loops(
simple_stride_1, simple_stride_1.find_loop("i"), VEC_W, memory, precision
)
simple_stride_1 = replace_all(simple_stride_1, instructions)

return simple_stride_1


exo_sswap_stride_any = specialize_precision("f32")
exo_sswap_stride_any = rename(
exo_sswap_stride_any, exo_sswap_stride_any.name() + "_stride_any"
)

f32_instructions = [
C.Machine.load_instr_f32,
C.Machine.store_instr_f32,
C.Machine.reg_copy_instr_f32,
template_sched_list = [
(swap_template, schedule_swap),
]
exo_sswap_stride_1 = schedule_swap_stride_1(
C.Machine.vec_width, C.Machine.mem_type, f32_instructions, "f32"
)

exo_dswap_stride_any = specialize_precision("f64")
exo_dswap_stride_any = rename(
exo_dswap_stride_any, exo_dswap_stride_any.name() + "_stride_any"
)
for precision in ("f32", "f64"):
for template, sched in template_sched_list:
proc_stride_any = generate_stride_any_proc(template, precision)
export_exo_proc(globals(), proc_stride_any)
proc_stride_1 = sched(
template,
Level_1_Params(
precision=precision, accumulators_count=1, interleave_factor=4
),
)
export_exo_proc(globals(), proc_stride_1)

f64_instructions = [
C.Machine.load_instr_f64,
C.Machine.store_instr_f64,
C.Machine.reg_copy_instr_f64,
]
if None not in f64_instructions:
exo_dswap_stride_1 = schedule_swap_stride_1(
C.Machine.vec_width // 2, C.Machine.mem_type, f64_instructions, "f64"
)
else:
exo_dswap_stride_1 = specialize_precision("f64")
exo_dswap_stride_1 = rename(
exo_dswap_stride_1, exo_dswap_stride_1.name() + "_stride_1"
)
### EXO_LOC SCHEDULE END ###

entry_points = [
exo_sswap_stride_any,
exo_sswap_stride_1,
exo_dswap_stride_any,
exo_dswap_stride_1,
]

if __name__ == "__main__":
for p in entry_points:
print(p)

__all__ = [p.name() for p in entry_points]

0 comments on commit f052c68

Please sign in to comment.