Skip to content

Commit

Permalink
Rewrite syr2 (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamirDroubi authored Feb 16, 2024
1 parent a9c971e commit 9352184
Show file tree
Hide file tree
Showing 12 changed files with 125 additions and 210 deletions.
28 changes: 13 additions & 15 deletions src/common/blaslib.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,21 @@ def optimize_level_1(proc, loop, precision, machine, interleave_factor):
instructions = machine.get_instructions(precision)

loop = proc.forward(loop)

proc = cse(proc, loop.body())
proc = cse(proc, loop.body(), precision)

# Vectorization
vectorize_tail = memory in {AVX2}
tail = "predicate" if vectorize_tail else "cut"
tail = "cut_and_predicate" if vectorize_tail else "cut"
proc, (loop,) = vectorize(
proc, loop, vec_width, precision, memory, tail=tail, rc=True
)

# Hoist any stmt
proc, (_, loop) = hoist_from_loop(proc, loop, rc=True)

if vectorize_tail:
proc = cut_tail_and_unguard(proc, loop)
proc, (_, loop) = hoist_from_loop(proc, loop, rc=True)
proc = interleave_loop(
proc, loop, interleave_factor, par_reduce=True, memory=memory
)

if interleave_factor > 1:
proc = interleave_loop(
proc, loop, interleave_factor, par_reduce=True, memory=memory
)

proc = cleanup(proc)
proc = replace_all_stmts(proc, instructions)
Expand All @@ -64,20 +58,24 @@ def get_triangle_type(proc, loop):
return 0


def optimize_level_2(proc, outer_loop, precision, machine, rows_factor, cols_factor):
def optimize_level_2(
proc, outer_loop, precision, machine, rows_factor, cols_factor, round_up=None
):
vec_width = machine.vec_width(precision)
memory = machine.mem_type
inner_loop = get_inner_loop(proc, outer_loop)

if triangle := get_triangle_type(proc, inner_loop):
round_up = memory in {AVX2}
if round_up is None:
round_up = memory in {AVX2}
rows_factor = min(rows_factor, vec_width)
if round_up and triangle == 1:
proc, (outer_loop,) = cut_loop_and_unroll(proc, outer_loop, 1, rc=True)
inner_loop = get_inner_loop(proc, outer_loop)
if not round_up and triangle == 2:
proc, (inner_loop,) = cut_loop_and_unroll(proc, inner_loop, 1, rc=True)

proc, (inner_loop,) = cut_loop_and_unroll(
proc, inner_loop, 1, front=False, rc=True
)
proc = round_loop(proc, inner_loop, vec_width, up=round_up)

proc = parallelize_all_reductions(proc, inner_loop, 1, unroll=True)
Expand Down
6 changes: 3 additions & 3 deletions src/common/codegen_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ def stage_scalar_args(proc):
return proc


def identity_schedule(proc, *args):
def identity_schedule(proc, *args, **kwargs):
return proc


def variants_generator(blas_op):
def generate(proc, loop_name, *args, globals=None):
def generate(proc, loop_name, *args, globals=None, **kwargs):
for precision in ("f32", "f64"):
proc_variant = specialize_precision(proc, precision)

Expand All @@ -98,7 +98,7 @@ def generate(proc, loop_name, *args, globals=None):

stride_1 = generate_stride_1_proc(proc_variant)
loop = stride_1.find_loop(loop_name)
stride_1 = blas_op(stride_1, loop, precision, C.Machine, *args)
stride_1 = blas_op(stride_1, loop, precision, C.Machine, *args, **kwargs)
stride_1 = bind_builtins_args(stride_1, stride_1.body(), precision)
export_exo_proc(globals, stride_1)

Expand Down
48 changes: 37 additions & 11 deletions src/common/composed_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,8 @@ def interleave_loop(proc, loop, factor=None, par_reduce=False, memory=DRAM, tail
s2 x c
s3 x c
"""
if factor == 1:
return proc

loop = proc.forward(loop)

Expand Down Expand Up @@ -641,7 +643,11 @@ def jam_stmt(proc, stmt, unsafe_disable_check=False, rc=False):
raise BLAS_SchedulingError("Next statement must be a loop.")

proc = add_loop(
proc, stmt, loop.name(), FormattedExprStr("_ - _", loop.hi(), loop.lo())
proc,
stmt,
loop.name(),
FormattedExprStr("_ - _", loop.hi(), loop.lo()),
unsafe_disable_check=unsafe_disable_check,
)
stmt = proc.forward(stmt)
stmt_loop = proc.forward(stmt).parent()
Expand Down Expand Up @@ -822,8 +828,11 @@ def cut_tail_and_unguard(proc, loop):
loop = proc.forward(loop)

# Cut the tail iterations
last_outer_iteration = FormattedExprStr("_ - 1", loop.hi())
proc = cut_loop(proc, loop, last_outer_iteration)
success = True
while success:
loop = proc.forward(loop)
cut = FormattedExprStr("_ - 1", loop.hi())
proc, success = attempt(cut_loop)(proc, loop, cut, rs=True)

# Unguard
proc = dce(proc, loop)
Expand Down Expand Up @@ -879,11 +888,16 @@ def divide_and_predicate_stmts(proc, loop, factor, rc=True):
proc = simplify(proc)
proc = fission_into_singles(proc, inner.body()[0])
for (stmt, allocs) in hoisted[::-1]:
proc, (outer, _) = jam_stmt(proc, stmt, unsafe_disable_check=True, rc=True)
proc, (inner, _) = jam_stmt(proc, stmt, unsafe_disable_check=True, rc=True)
if isinstance(stmt.next(), InvalidCursor):
s = stmt.parent()
else:
s = stmt
proc, (outer, _) = jam_stmt(proc, s, unsafe_disable_check=True, rc=True)
proc, (inner, _) = jam_stmt(proc, s, unsafe_disable_check=True, rc=True)
proc = apply(sink_alloc)(proc, allocs)
proc = apply(sink_alloc)(proc, allocs)
proc = dce(proc, inner)

if not rc:
return proc
outer = proc.forward(outer)
Expand Down Expand Up @@ -916,7 +930,7 @@ def vectorize_predicate_tail(
children_ops = [fma_rule, abs_rule]
proc = stage_compute(proc, loop, precision, mem_type, children_ops)

proc, (outer, inner) = divide_and_predicate_stmts(proc, loop, vec_width, rc=True)
proc, (outer, inner, _) = auto_divide_loop(proc, loop, vec_width)
proc = simplify(proc)
proc = fission_into_singles(proc, inner)

Expand Down Expand Up @@ -1667,7 +1681,7 @@ def rewrite(proc, stmt, values):
return proc, binary_specialize_cursors(filtered_stmts)


def cse(proc, block):
def cse(proc, block, precision):
if not isinstance(block, BlockCursor):
block = proc.forward(block).as_block()

Expand All @@ -1687,8 +1701,11 @@ def cse(proc, block):
for buff, idx_map in buff_map.items():
for idx, access_list in idx_map.items():
if len(access_list) > 1:
staging_block = get_bounding_block(proc, access_list)
proc = auto_stage_mem(proc, staging_block, buff)
if all(is_read(proc, c) for c in access_list):
proc = bind_and_set_expr(proc, access_list, precision, DRAM)
else:
staging_block = get_bounding_block(proc, access_list)
proc = auto_stage_mem(proc, staging_block, buff)
return proc


Expand Down Expand Up @@ -1761,9 +1778,18 @@ def __iter__(self):
yield self.loop


def cut_loop_and_unroll(proc, loop, const, rc=False):
def cut_loop_and_unroll(proc, loop, const, front=True, rc=False):
loop = proc.forward(loop)
proc, (const_loop, loop) = cut_loop_(proc, loop, const, rc=True)
cut = (
FormattedExprStr("_ + 1", loop.lo())
if front
else FormattedExprStr("_ - 1", loop.hi())
)
proc, (const_loop, loop) = cut_loop_(proc, loop, cut, rc=True)
if not front:
const_loop, loop = loop, const_loop
proc = shift_loop(proc, const_loop, 0)
proc = simplify(proc)
proc = unroll_loop(proc, const_loop)
proc = shift_loop(proc, loop, 0)
if not rc:
Expand Down
6 changes: 5 additions & 1 deletion src/common/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,12 @@ def is_read(proc, read, name=None):
return isinstance(read, ReadCursor) and (name is None or read.name() == name)


def is_write(proc, write):
return is_reduce(proc, write) or is_assign(proc, write)


def is_access(proc, access):
return is_read(proc, access) or is_reduce(proc, access) or is_assign(proc, access)
return is_read(proc, access) or is_write(proc, access)


def is_unary_minus(proc, expr):
Expand Down
8 changes: 6 additions & 2 deletions src/level2/symv.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,9 @@ def symv_rm_l(n: size, alpha: R, A: [R][n, n], x: [R][n], beta: R, y: [R][n]):
y[i] += alpha * (dot - A[i, i] * x[i])


variants_generator(optimize_level_2)(symv_rm_u, "i #1", 4, 1, globals=globals())
variants_generator(optimize_level_2)(symv_rm_l, "i #1", 4, 1, globals=globals())
variants_generator(optimize_level_2)(
symv_rm_u, "i #1", 4, 1, round_up=False, globals=globals()
)
variants_generator(optimize_level_2)(
symv_rm_l, "i #1", 4, 1, round_up=False, globals=globals()
)
145 changes: 13 additions & 132 deletions src/level2/syr2.py
Original file line number Diff line number Diff line change
@@ -1,161 +1,42 @@
from __future__ import annotations

from exo import *
from exo.libs.memories import DRAM_STATIC
from exo.platforms.x86 import *
from exo.syntax import *
from exo.stdlib.scheduling import *

import exo_blas_config as C
from blaslib import *
from codegen_helpers import *


### EXO_LOC ALGORITHM START ###
@proc
def syr2_row_major_Upper(n: size, alpha: R, x: [R][n], y: [R][n], A: [R][n, n]):
def syr2_rm_u(n: size, alpha: R, x: [R][n], y: [R][n], A: [R][n, n]):
assert stride(A, 1) == 1

for i in seq(0, n):
for j in seq(0, n - i):
A[i, i + j] += alpha * x[i] * y[i + j] + alpha * y[i] * x[i + j]
for j in seq(0, i + 1):
A[n - i - 1, n - j - 1] += (alpha * x[n - i - 1]) * y[n - j - 1] + (
alpha * y[n - i - 1]
) * x[n - j - 1]


@proc
def syr2_row_major_Lower(n: size, alpha: R, x: [R][n], y: [R][n], A: [R][n, n]):
def syr2_rm_l(n: size, alpha: R, x: [R][n], y: [R][n], A: [R][n, n]):
assert stride(A, 1) == 1

for i in seq(0, n):
for j in seq(0, i + 1):
A[i, j] += alpha * x[i] * y[j] + alpha * y[i] * x[j]
A[i, j] += (alpha * x[i]) * y[j] + (alpha * y[i]) * x[j]


### EXO_LOC ALGORITHM END ###


### EXO_LOC SCHEDULE START ###
def specialize_syr2(syr2, precision):
prefix = "s" if precision == "f32" else "d"
name = syr2.name()
name = name.replace("", "")
specialized = rename(syr2, "exo_" + prefix + name)

args = ["alpha", "x", "y", "A"]

for arg in args:
specialized = set_precision(specialized, arg, precision)

return simplify(specialized)


def schedule_interleave_syr2_row_major_stride_1(
syr2, VEC_W, INTERLEAVE_FACTOR, memory, instructions, precision
):
stride_1 = specialize_syr2(syr2, precision)
stride_1 = rename(stride_1, stride_1.name() + "_stride_1")
stride_1 = stride_1.add_assertion("stride(x, 0) == 1")
stride_1 = stride_1.add_assertion("stride(y, 0) == 1")

return simplify(stride_1)


#################################################
# Kernel Parameters
#################################################

#################################################
# Generate specialized kernels for f32 precision
#################################################

exo_ssyr2_row_major_Upper_stride_any = specialize_syr2(syr2_row_major_Upper, "f32")
exo_ssyr2_row_major_Upper_stride_any = rename(
exo_ssyr2_row_major_Upper_stride_any,
exo_ssyr2_row_major_Upper_stride_any.name() + "_stride_any",
variants_generator(optimize_level_2)(
syr2_rm_u, "i", 4, 1, round_up=False, globals=globals()
)
exo_ssyr2_row_major_Lower_stride_any = specialize_syr2(syr2_row_major_Lower, "f32")
exo_ssyr2_row_major_Lower_stride_any = rename(
exo_ssyr2_row_major_Lower_stride_any,
exo_ssyr2_row_major_Lower_stride_any.name() + "_stride_any",
variants_generator(optimize_level_2)(
syr2_rm_l, "i", 4, 1, round_up=False, globals=globals()
)
f32_instructions = [
C.Machine.load_instr_f32,
C.Machine.store_instr_f32,
C.Machine.mul_instr_f32,
C.Machine.fmadd_reduce_instr_f32,
C.Machine.broadcast_instr_f32,
C.Machine.broadcast_scalar_instr_f32,
]

exo_ssyr2_row_major_Upper_stride_1 = schedule_interleave_syr2_row_major_stride_1(
syr2_row_major_Upper,
C.Machine.f32_vec_width,
1,
C.Machine.mem_type,
f32_instructions,
"f32",
)
exo_ssyr2_row_major_Lower_stride_1 = schedule_interleave_syr2_row_major_stride_1(
syr2_row_major_Lower,
C.Machine.f32_vec_width,
1,
C.Machine.mem_type,
f32_instructions,
"f32",
)

#################################################
# Generate specialized kernels for f64 precision
#################################################

exo_dsyr2_row_major_Upper_stride_any = specialize_syr2(syr2_row_major_Upper, "f64")
exo_dsyr2_row_major_Upper_stride_any = rename(
exo_dsyr2_row_major_Upper_stride_any,
exo_dsyr2_row_major_Upper_stride_any.name() + "_stride_any",
)
exo_dsyr2_row_major_Lower_stride_any = specialize_syr2(syr2_row_major_Lower, "f64")
exo_dsyr2_row_major_Lower_stride_any = rename(
exo_dsyr2_row_major_Lower_stride_any,
exo_dsyr2_row_major_Lower_stride_any.name() + "_stride_any",
)

f64_instructions = [
C.Machine.load_instr_f64,
C.Machine.store_instr_f64,
C.Machine.mul_instr_f64,
C.Machine.fmadd_reduce_instr_f64,
C.Machine.broadcast_instr_f64,
C.Machine.broadcast_scalar_instr_f64,
]

exo_dsyr2_row_major_Upper_stride_1 = schedule_interleave_syr2_row_major_stride_1(
syr2_row_major_Upper,
C.Machine.f32_vec_width // 2,
1,
C.Machine.mem_type,
f64_instructions,
"f64",
)
exo_dsyr2_row_major_Lower_stride_1 = schedule_interleave_syr2_row_major_stride_1(
syr2_row_major_Lower,
C.Machine.f32_vec_width // 2,
1,
C.Machine.mem_type,
f64_instructions,
"f64",
)
### EXO_LOC SCHEDULE END ###

entry_points = [
exo_ssyr2_row_major_Upper_stride_any,
exo_ssyr2_row_major_Upper_stride_1,
exo_dsyr2_row_major_Upper_stride_any,
exo_dsyr2_row_major_Upper_stride_1,
exo_ssyr2_row_major_Lower_stride_any,
exo_ssyr2_row_major_Lower_stride_1,
exo_dsyr2_row_major_Lower_stride_any,
exo_dsyr2_row_major_Lower_stride_1,
]

if __name__ == "__main__":
for p in entry_points:
print(p)

__all__ = [p.name() for p in entry_points]
Loading

0 comments on commit 9352184

Please sign in to comment.