Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gemm macro micro #90

Merged
merged 14 commits into from
Mar 20, 2024
7 changes: 4 additions & 3 deletions CMakePresets.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@
"generator": "Ninja",
"binaryDir": "${sourceDir}/build/avx2",
"cacheVariables": {
"CMAKE_C_FLAGS": "-march=native -fno-tree-vectorize -fno-unroll-loops -O3 -ffast-math -std=c11",
"CMAKE_CXX_FLAGS": "-std=c++17 -march=native -fno-tree-vectorize -fno-unroll-loops -O3 -ffast-math",
"CMAKE_C_FLAGS": "-march=core-avx2 -O3 -ffast-math -std=c11",
"CMAKE_CXX_FLAGS": "-std=c++17 -march=core-avx2 -O3 -ffast-math",
"CXX_STANDARD" : "C++17",
"CMAKE_BUILD_TYPE": "Release",
"TARGET_ARCH": "avx2"
"TARGET_ARCH": "avx2",
"MKL_ENABLE_INSTRUCTIONS": "AVX2"
}
}
]
Expand Down
11 changes: 7 additions & 4 deletions analytics_tools/graphing/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ def kernel_graphs_dir(kernel):


def help_msg():
print("python graph.py <kernel name>!")
print("python graph.py <kernel name> <verbose?>!")
exit(1)


def check_args():
if len(sys.argv) != 2:
if len(sys.argv) > 3:
help_msg()


Expand Down Expand Up @@ -146,7 +146,7 @@ def plot_bandwidth_throughput(kernel, data, peaks, loads=True):
plt.savefig(filename)


def plot_flops_throughput(kernel, data, peaks):
def plot_flops_throughput(kernel, data, peaks, verbose):
plt.clf()

some_point = next(iter(data.values()))[0]
Expand All @@ -163,6 +163,8 @@ def plot_flops_throughput(kernel, data, peaks):
assert len(runs) == len(set(runs)) # No duplicates
x = [run.get_size_param() for run in sorted_runs]
y = [run.get_gflops_per_sec() for run in sorted_runs]
if verbose:
print(libname, ": ", list(zip(x, y)))
plt.plot(x, y, label=libname)

plt.legend()
Expand All @@ -185,6 +187,7 @@ def plot_flops_throughput(kernel, data, peaks):
check_args()

kernel = sys.argv[1]
verbose = sys.argv[2]

init_directories(kernel)
jsons = get_jsons(kernel)
Expand All @@ -196,4 +199,4 @@ def plot_flops_throughput(kernel, data, peaks):
for data in bench_type_dict.values():
plot_bandwidth_throughput(kernel, data, peaks, loads=True)
plot_bandwidth_throughput(kernel, data, peaks, loads=False)
plot_flops_throughput(kernel, data, peaks)
plot_flops_throughput(kernel, data, peaks, verbose)
6 changes: 4 additions & 2 deletions src/common/blaslib.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ def optimize_level_1(
precision,
machine,
interleave_factor,
instrs=None,
vec_tail=None,
inter_tail="recursive",
):
vec_width = machine.vec_width(precision)
memory = machine.mem_type
instructions = machine.get_instructions(precision)
if instrs is None:
instrs = machine.get_instructions(precision)

if vec_tail is None:
vec_tail = "cut_and_predicate" if machine.supports_predication else "cut"
Expand All @@ -44,7 +46,7 @@ def optimize_level_1(
)

proc = cleanup(proc)
proc = replace_all_stmts(proc, instructions)
proc = replace_all_stmts(proc, instrs)
return proc


Expand Down
4 changes: 2 additions & 2 deletions src/common/codegen_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def export_perf_features(kernel_name, perf_features):
json.dump(perf_features, f, sort_keys=True, indent=4, separators=(",", ": "))


def variants_generator(blas_op, opt_precisions=("f32", "f64")):
def variants_generator(blas_op, opt_precisions=("f32", "f64"), targets=(AVX2, Neon)):
def generate(proc, loop_name, *args, globals=None, **kwargs):
perf_features = {}
for precision in ("f32", "f64"):
Expand All @@ -126,7 +126,7 @@ def generate(proc, loop_name, *args, globals=None, **kwargs):
loop = stride_1.find_loop(loop_name)
algorithm = get_perf_features(stride_1)

if precision in opt_precisions:
if precision in opt_precisions and C.Machine.mem_type in targets:
stride_1 = blas_op(
stride_1, loop, precision, C.Machine, *args, **kwargs
)
Expand Down
2 changes: 1 addition & 1 deletion src/common/higher_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def rewrite(proc, block, subproc_name, *args, rc=False, **kwargs):
proc, subproc = extract_subproc(proc, block, subproc_name)
subproc_sched = op(subproc, *args, **kwargs)
call = proc.forward(block)[0]
proc = call_eqv(proc, call, subproc)
proc = call_eqv(proc, call, subproc_sched)
if not rc:
return proc
return proc, (subproc, subproc_sched)
Expand Down
101 changes: 53 additions & 48 deletions src/common/stdlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,69 +585,72 @@ def tile_loops_top_down(proc, loop_tile_pairs):
return proc, [proc.forward(l) for l in inner_loops]


def tile_loops_bottom_up(proc, loop, tiles, const_allocs=True):
def tile_loops_bottom_up(proc, loop, tiles, const_allocs=True, tail="cut_and_guard"):
loop = proc.forward(loop)

loops = [(loop, tiles[0])]
cur_loop = loop
for i in tiles[:-1]:
if not len(cur_loop.body()) == 1:
raise BLAS_SchedulingError("All loop must have a body length of 1")
if not isinstance(cur_loop.body()[0], ForCursor):
raise BLAS_SchedulingError("Did not find a nested loop")
cur_loop = cur_loop.body()[0]

loops = []
cur_loop = loop
for i in tiles:
loops.append((cur_loop, i))
cur_loop = cur_loop.body()[0]
for tile in tiles[1:]:
cur_loop = get_inner_loop(proc, cur_loop)
loops.append((cur_loop, tile))

def get_depth(loop):
if not isinstance(loop, (ForCursor, IfCursor)):
def get_depth(proc, scope):
scope = proc.forward(scope)
if not isinstance(scope, (ForCursor, IfCursor)):
return 0
return max([get_depth(i) for i in loop.body()]) + 1
return max([get_depth(proc, i) for i in scope.body()]) + 1

def push_loop_in(proc, loop, depth, size=None):
loop = proc.forward(loop)
allocs = list(filter_cursors(is_alloc)(proc, loop.body()))
def push_scope_in(proc, scope, depth, size=None):
scope = proc.forward(scope)
allocs = list(filter_cursors(is_alloc)(proc, scope.body()))
proc = apply(attempt(parallelize_and_lift_alloc))(proc, allocs)

if const_allocs and size:
proc = apply(lambda p, a: resize_dim(p, a, 0, size, 0))(proc, allocs)
loop = proc.forward(loop)
count = len(loop.body())
for stmt in list(loop.body())[:-1]:
proc = fission(proc, stmt.after())
loop = proc.forward(loop)
loops = []
for i in range(count):
loops.append(loop)
loop = loop.next()
for loop in loops:
if get_depth(loop) <= depth:

scope = proc.forward(scope)
count = len(scope.body())
scopes = [scope]
for stmt in list(scope.body())[:-1]:
proc, (scope1, scope2) = fission_(proc, stmt.after(), rc=True)
scopes = scopes[:-1]
scopes += [scope1, scope2]

for scope in scopes:
if get_depth(proc, scope) <= depth:
continue
loop = proc.forward(loop)
child = loop.body()[0]
if isinstance(child, ForCursor):
proc = reorder_loops(proc, loop)
forwarded_loop = proc.forward(loop)
elif isinstance(child, IfCursor):
scope = proc.forward(scope)
child = scope.body()[0]
if isinstance(child, (ForCursor, IfCursor)):
proc = lift_scope(proc, child)
child = proc.forward(child)
forwarded_loop = child.body()[0]
forwarded_scope = child.body()[0]
else:
continue
proc = push_loop_in(proc, forwarded_loop, depth, size)
proc = push_scope_in(proc, forwarded_scope, depth, size)
return proc

guards = 0
for depth, (loop, tile) in enumerate(loops[::-1]):
if tile is not None:
proc, (_, inner, tail) = divide_loop_(
proc, loop, tile, tail="cut_and_guard", rc=True
)
proc = push_loop_in(proc, inner, depth + 1, tile)
proc = push_loop_in(proc, tail.body()[0], depth + 1, tile)
else:
proc = push_loop_in(proc, loop, depth + 1)

if tail == "cut_and_guard":
if tile is not None:
proc, (_, inner, tail_l) = divide_loop_(
proc, loop, tile, tail=tail, rc=True
)
proc = simplify(proc)
proc = push_scope_in(proc, inner, depth + 1, tile)
proc = push_scope_in(proc, tail_l.body()[0], depth + 1, tile)
else:
proc = push_scope_in(proc, loop, depth + 1)
elif tail == "guard":
if tile is not None:
proc, (_, inner, _) = divide_loop_(proc, loop, tile, tail=tail, rc=True)
proc = simplify(proc)
proc = push_scope_in(proc, inner.body()[0], guards + 1)
guards += 1
proc = push_scope_in(proc, inner, depth + guards + 1, tile)
else:
proc = push_scope_in(proc, loop, depth + guards + 1)
return proc


Expand Down Expand Up @@ -1393,7 +1396,9 @@ def add_loop_nest(loop):
for src_dim, _ in reversed(list(enumerate(alloc.shape()))):
for dst_dim, size in reversed(divisions[src_dim][1:]):
proc = divide_dim(proc, alloc, src_dim, size)
# proc = apply(divide_loop_)(proc, loop_nest[src_dim], size, tail="cut") # TODO: This should be enabled but it slows down compilation
proc = apply(divide_loop_)(
proc, loop_nest[src_dim], size, tail="cut"
) # TODO: This should be enabled but it slows down compilation
perm.append(dst_dim)
perm.append(divisions[src_dim][0][0])

Expand Down
8 changes: 1 addition & 7 deletions src/level1/asum.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,5 @@ def asum(n: size, x: [f32][n] @ DRAM, result: f32 @ DRAM):
### EXO_LOC ALGORITHM END ###

### EXO_LOC SCHEDULE START ###
def schedule_asum(asum, loop, precision, machine, interleave_factor):
if machine.mem_type is not AVX2:
return asum
return optimize_level_1(asum, loop, precision, machine, interleave_factor)


variants_generator(schedule_asum)(asum, "i", 8, globals=globals())
variants_generator(optimize_level_1, targets=(AVX2,))(asum, "i", 8, globals=globals())
### EXO_LOC SCHEDULE END ###
Loading
Loading