Skip to content

Commit

Permalink
Automation of precision-stride variants generation (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamirDroubi authored Feb 8, 2024
1 parent 68dd974 commit 4ebdbdb
Show file tree
Hide file tree
Showing 25 changed files with 224 additions and 504 deletions.
63 changes: 46 additions & 17 deletions src/common/codegen_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,44 +9,53 @@
from exo.API_cursors import *

from introspection import *
from higher_order import *
import exo_blas_config as C


def specialize_precision(template_proc, precision):
def specialize_precision(proc, precision, all_buffs=True):
assert precision in {"f32", "f64"}
prefix = "s" if precision == "f32" else "d"
template_name = template_proc.name()
template_name = template_name.replace("_template", "")
specialized_proc = rename(template_proc, "exo_" + prefix + template_name)
template_name = proc.name()
proc = rename(proc, prefix + template_name)

for arg in template_proc.args():
if arg.type().is_numeric():
specialized_proc = set_precision(specialized_proc, arg, precision)
def has_type_R(proc, s, *arg):
if not isinstance(s, (AllocCursor, ArgCursor)):
return False
return s.type() == ExoType.R

for stmt in lrn_stmts(template_proc):
if isinstance(stmt, AllocCursor):
specialized_proc = set_precision(specialized_proc, stmt, precision)
return specialized_proc
set_R_type = predicate(set_precision, has_type_R)

def is_numeric(proc, s, *arg):
if not isinstance(s, (AllocCursor, ArgCursor)):
return False
return s.type().is_numeric()

def generate_stride_any_proc(template_proc, precision):
proc = specialize_precision(template_proc, precision)
set_numerics = predicate(set_precision, is_numeric)

set_type = set_numerics if all_buffs else set_R_type
proc = apply(set_type)(proc, proc.args(), precision)
proc = make_pass(set_type)(proc, proc.body(), precision)
return proc


def generate_stride_any_proc(proc):
proc = rename(proc, proc.name() + "_stride_any")
proc = stage_scalar_args(proc)
return proc


def generate_stride_1_proc(template_proc, precision):
proc = specialize_precision(template_proc, precision)
def generate_stride_1_proc(proc):
proc = rename(proc, proc.name() + "_stride_1")
for arg in proc.args():
if arg.is_tensor():
proc = proc.add_assertion(
f"stride({arg.name()}, {len(arg.shape()) - 1}) == 1"
)
proc = stage_scalar_args(proc)
return proc


def export_exo_proc(globals, proc):
proc = rename(proc, f"exo_{proc.name()}")
globals[proc.name()] = simplify(proc)
globals.setdefault("__all__", []).append(proc.name())

Expand All @@ -70,3 +79,23 @@ def stage_scalar_args(proc):
if arg.type().is_numeric() and not arg.is_tensor():
proc = stage_mem(proc, proc.body(), arg.name(), f"{arg.name()}_")
return proc


def variants_generator(blas_op):
def generate(proc, loop_name, *args, globals=None):
for precision in ("f32", "f64"):
proc_variant = specialize_precision(proc, precision)

proc_variant = stage_scalar_args(proc_variant)

stride_any = generate_stride_any_proc(proc_variant)
stride_any = bind_builtins_args(stride_any, stride_any.body(), precision)
export_exo_proc(globals, stride_any)

stride_1 = generate_stride_1_proc(proc_variant)
loop = stride_1.find_loop(loop_name)
stride_1 = blas_op(stride_1, loop, precision, C.Machine, *args)
stride_1 = bind_builtins_args(stride_1, stride_1.body(), precision)
export_exo_proc(globals, stride_1)

return generate
36 changes: 5 additions & 31 deletions src/level1/asum.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
from __future__ import annotations

from exo import *
from exo.libs.memories import DRAM_STATIC
from exo.platforms.x86 import *
from exo.syntax import *
from exo.stdlib.scheduling import *

import exo_blas_config as C
from composed_schedules import *
from blaslib import *
from codegen_helpers import *

Expand All @@ -22,32 +17,11 @@ def asum(n: size, x: [f32][n] @ DRAM, result: f32 @ DRAM):
### EXO_LOC ALGORITHM END ###

### EXO_LOC SCHEDULE START ###
def schedule_asum_stride_1(asum, precision):
asum = generate_stride_1_proc(asum, precision)
if C.Machine.mem_type is not AVX2:
def schedule_asum(asum, loop, precision, machine, interleave_factor):
if machine.mem_type is not AVX2:
return asum
asum = optimize_level_1(asum, asum.find_loop("i"), precision, C.Machine, 7)
return asum


template_sched_list = [
(asum, schedule_asum_stride_1),
]

for precision in ("f32", "f64"):
for template, sched in template_sched_list:
proc_stride_any = generate_stride_any_proc(template, precision)
proc_stride_any = bind_builtins_args(
proc_stride_any, proc_stride_any.body(), precision
)
export_exo_proc(globals(), proc_stride_any)
proc_stride_1 = sched(
template,
precision,
)
proc_stride_1 = bind_builtins_args(
proc_stride_1, proc_stride_1.body(), precision
)
export_exo_proc(globals(), proc_stride_1)
return optimize_level_1(asum, loop, precision, machine, interleave_factor)


variants_generator(schedule_asum)(asum, "i", 7, globals=globals())
### EXO_LOC SCHEDULE END ###
33 changes: 5 additions & 28 deletions src/level1/axpy.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,19 @@
from __future__ import annotations

from exo import *
from exo.libs.memories import DRAM_STATIC
from exo.platforms.x86 import *
from exo.syntax import *
from exo.stdlib.scheduling import *
from exo.API_cursors import *

import exo_blas_config as C
from composed_schedules import *

from blaslib import *
from codegen_helpers import *

### EXO_LOC ALGORITHM START ###
@proc
def axpy_template(n: size, alpha: R, x: [R][n], y: [R][n]):
def axpy(n: size, alpha: R, x: [R][n], y: [R][n]):
for i in seq(0, n):
y[i] += alpha * x[i]


@proc
def axpy_template_alpha_1(n: size, x: [R][n], y: [R][n]):
def axpy_alpha_1(n: size, x: [R][n], y: [R][n]):
for i in seq(0, n):
y[i] += x[i]

Expand All @@ -29,22 +22,6 @@ def axpy_template_alpha_1(n: size, x: [R][n], y: [R][n]):


### EXO_LOC SCHEDULE START ###
def schedule_axpy_stride_1(axpy, precision):
axpy = generate_stride_1_proc(axpy, precision)
main_loop = axpy.find_loop("i")
axpy = optimize_level_1(axpy, main_loop, precision, C.Machine, 4)
return simplify(axpy)


template_sched_list = [
(axpy_template, schedule_axpy_stride_1),
(axpy_template_alpha_1, schedule_axpy_stride_1),
]

for precision in ("f32", "f64"):
for template, sched in template_sched_list:
proc_stride_any = generate_stride_any_proc(template, precision)
export_exo_proc(globals(), proc_stride_any)
proc_stride_1 = sched(template, precision)
export_exo_proc(globals(), proc_stride_1)
for proc in axpy, axpy_alpha_1:
variants_generator(optimize_level_1)(proc, "i", 4, globals=globals())
### EXO_LOC SCHEDULE END ###
29 changes: 3 additions & 26 deletions src/level1/dot.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
from __future__ import annotations

from exo import *
from exo.libs.memories import DRAM_STATIC
from exo.platforms.x86 import *
from exo.syntax import *
from exo.stdlib.scheduling import *
import exo.API_cursors as pc

import exo_blas_config as C
from composed_schedules import *

from blaslib import *
from codegen_helpers import *


### EXO_LOC ALGORITHM START ###
@proc
def dot_template(n: size, x: [R][n], y: [R][n], result: R):
def dot(n: size, x: [R][n], y: [R][n], result: R):
result = 0.0
for i in seq(0, n):
result += x[i] * y[i]
Expand All @@ -25,21 +18,5 @@ def dot_template(n: size, x: [R][n], y: [R][n], result: R):


### EXO_LOC SCHEDULE START ###
def schedule_dot_stride_1(dot, precision):
dot = generate_stride_1_proc(dot, precision)
main_loop = dot.find_loop("i")
dot = optimize_level_1(dot, main_loop, precision, C.Machine, 4)
return simplify(dot)


template_sched_list = [
(dot_template, schedule_dot_stride_1),
]

for precision in ("f32", "f64"):
for template, sched in template_sched_list:
proc_stride_any = generate_stride_any_proc(template, precision)
export_exo_proc(globals(), proc_stride_any)
proc_stride_1 = sched(template, precision)
export_exo_proc(globals(), proc_stride_1)
variants_generator(optimize_level_1)(dot, "i", 4, globals=globals())
### EXO_LOC SCHEDULE END ###
40 changes: 10 additions & 30 deletions src/level1/dsdot.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
from __future__ import annotations

from exo import *
from exo.libs.memories import DRAM_STATIC
from exo.platforms.x86 import *
from exo.platforms.neon import *
from exo.syntax import *
from exo.stdlib.scheduling import *

import exo_blas_config as C
from composed_schedules import *
from blaslib import *
from codegen_helpers import *
import exo_blas_config as C

### EXO_LOC ALGORITHM START ###
@proc
def dsdot_template(n: size, x: [f32][n], y: [f32][n], result: f64):
def dsdot(n: size, x: [f32][n], y: [f32][n], result: f64):
d_dot: f64
d_dot = 0.0
for i in seq(0, n):
Expand All @@ -27,7 +22,7 @@ def dsdot_template(n: size, x: [f32][n], y: [f32][n], result: f64):


@proc
def sdsdot_template(n: size, sb: f32, x: [f32][n], y: [f32][n], result: f32):
def sdsdot(n: size, sb: f32, x: [f32][n], y: [f32][n], result: f32):
d_result: f64
d_dot: f64
d_dot = 0.0
Expand All @@ -46,26 +41,11 @@ def sdsdot_template(n: size, sb: f32, x: [f32][n], y: [f32][n], result: f32):


### EXO_LOC SCHEDULE START ###
def schedule_dsdot_stride_1(proc, name):
proc = rename(proc, name)
proc = proc.add_assertion("stride(x, 0) == 1")
proc = proc.add_assertion("stride(y, 0) == 1")

if C.Machine.mem_type is not AVX2:
return proc
proc = optimize_level_1(proc, proc.find_loop("i"), "f64", C.Machine, 4)
return proc


export_exo_proc(globals(), rename(dsdot_template, "exo_dsdot_stride_any"))
export_exo_proc(globals(), rename(sdsdot_template, "exo_sdsdot_stride_any"))
export_exo_proc(
globals(),
schedule_dsdot_stride_1(dsdot_template, "exo_dsdot_stride_1"),
)
export_exo_proc(
globals(),
schedule_dsdot_stride_1(sdsdot_template, "exo_sdsdot_stride_1"),
)

for proc in dsdot, sdsdot:
export_exo_proc(globals(), generate_stride_any_proc(proc))
stride_1 = generate_stride_1_proc(proc)
if C.Machine.mem_type is AVX2:
loop = stride_1.find_loop("i")
stride_1 = optimize_level_1(stride_1, loop, "f64", C.Machine, 4)
export_exo_proc(globals(), stride_1)
### EXO_LOC SCHEDULE END ###
26 changes: 2 additions & 24 deletions src/level1/exo_copy.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,20 @@
from __future__ import annotations

from exo import *
from exo.libs.memories import DRAM_STATIC
from exo.platforms.x86 import *
from exo.syntax import *
from exo.stdlib.scheduling import *

from blaslib import *
from codegen_helpers import *
import exo_blas_config as C


### EXO_LOC ALGORITHM START ###
@proc
def copy_template(n: size, x: [R][n], y: [R][n]):
def copy(n: size, x: [R][n], y: [R][n]):
for i in seq(0, n):
y[i] = x[i]


### EXO_LOC ALGORITHM END ###

### EXO_LOC SCHEDULE START ###
def schedule_copy(exo_copy, precision):
exo_copy = generate_stride_1_proc(exo_copy, precision)
main_loop = exo_copy.find_loop("i")
exo_copy = optimize_level_1(exo_copy, main_loop, precision, C.Machine, 4)
return simplify(exo_copy)


template_sched_list = [
(copy_template, schedule_copy),
]

for precision in ("f32", "f64"):
for template, sched in template_sched_list:
proc_stride_any = generate_stride_any_proc(template, precision)
export_exo_proc(globals(), proc_stride_any)
proc_stride_1 = sched(template, precision)
export_exo_proc(globals(), proc_stride_1)

variants_generator(optimize_level_1)(copy, "i", 4, globals=globals())
### EXO_LOC SCHEDULE END ###
4 changes: 2 additions & 2 deletions src/level1/nrm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

### EXO_LOC ALGORITHM START ###
@proc
def nrm2_template(n: size, x: [R][n], result: R):
def nrm2(n: size, x: [R][n], result: R):
result = 0.0
for i in seq(0, n):
result += x[i] * x[i]
Expand All @@ -23,7 +23,7 @@ def nrm2_template(n: size, x: [R][n], result: R):
### EXO_LOC SCHEDULE START ###
def specialize_precision(precision):
prefix = "s" if precision == "f32" else "d"
specialized_copy = rename(nrm2_template, "exo_" + prefix + "nrm2")
specialized_copy = rename(nrm2, "exo_" + prefix + "nrm2")
for arg in ["x", "result"]:
specialized_copy = set_precision(specialized_copy, arg, precision)
return specialized_copy
Expand Down
Loading

0 comments on commit 4ebdbdb

Please sign in to comment.