Skip to content

Commit

Permalink
Pack input matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
SamirDroubi committed Mar 17, 2024
1 parent f0aab7e commit 72553bb
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 22 deletions.
14 changes: 7 additions & 7 deletions src/common/rc_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ def divide_loop_(proc, loop_cursor, div_const, tail="guard", rc=False):
@dataclass
class stage_mem_cursors:
alloc: AllocCursor
load_stage: Cursor
load: Cursor
block: BlockCursor
store_stage: Cursor
store: Cursor

def __iter__(self):
yield self.alloc
yield self.load_stage
yield self.load
yield self.block
yield self.store_stage
yield self.store


def stage_mem_(proc, block, buff, new_buff_name, accum=False, rc=False):
Expand All @@ -105,12 +105,12 @@ def stage_mem_(proc, block, buff, new_buff_name, accum=False, rc=False):
block_first = proc.forward(block_first)
block_last = proc.forward(block_last)
alloc = block_first.prev().prev()
load_stage = block_first.prev()
load = block_first.prev()
block = block_first.as_block().expand(0, len(block) - 1)
store_stage = block_last.next()
store = block_last.next()
if not rc:
return proc
return proc, stage_mem_cursors(alloc, load_stage, block, store_stage)
return proc, stage_mem_cursors(alloc, load, block, store)


@dataclass
Expand Down
66 changes: 66 additions & 0 deletions src/common/stdlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,3 +1348,69 @@ def squash_buffers(proc, buffers):
proc = squash(proc, buffers[1:], buffer)
max_d = max(depths)
return proc


@dataclass
class pack_mem_cursors:
alloc: AllocCursor
load: Cursor
block: BlockCursor
store: Cursor

def __iter__(self):
yield self.alloc
yield self.load
yield self.block
yield self.store


def pack_mem(proc, block, buffer, shape, name=None, rc=False):
proc, (alloc, load, block, store) = auto_stage_mem(
proc, block, buffer, name, rc=True
)
bounds = [1] * len(alloc.shape())

loop_nest = [[] for _ in alloc.shape()]

def add_loop_nest(loop):
if is_invalid(proc, loop):
return
for idx, _ in enumerate(alloc.shape()):
loop_nest[idx].append(loop)
loop = loop.body()[0]

add_loop_nest(load)
add_loop_nest(store)

divisions = {}
for dst_dim, (src_dim, size) in enumerate(shape):
divisions.setdefault(src_dim, []).append((dst_dim, size))
bounds[src_dim] *= size
proc = bound_alloc(proc, alloc, bounds)
perm = []
for src_dim, _ in reversed(list(enumerate(alloc.shape()))):
for dst_dim, size in reversed(divisions[src_dim][1:]):
proc = divide_dim(proc, alloc, src_dim, size)
# proc = apply(divide_loop_)(proc, loop_nest[src_dim], size, tail="cut") # TODO: This should be enabled but it slows down compilation
perm.append(dst_dim)
perm.append(divisions[src_dim][0][0])

perm = perm[::-1]
final_perm = [0] * len(perm)
for i, val in enumerate(perm):
final_perm[val] = i
proc = rearrange_dim(proc, alloc, final_perm)
proc = simplify(proc)

if not rc:
return proc

alloc = proc.forward(alloc)
block = proc.forward(block)

if not is_invalid(proc, load):
load = proc.forward(load)
diff = get_index_in_body(proc, block[0]) - get_index_in_body(proc, load) - 1
load = load.as_block().expand(0, diff)

return proc, pack_mem_cursors(alloc, load, block, store)
26 changes: 15 additions & 11 deletions src/level3/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,29 @@ def rewrite(gemm_uk):
def schedule_macro(
gemm_mk, precision, machine, max_M, max_N, max_K, m_r, n_r_fac, do_br=False
):
vw = machine.vec_width(precision)
n_r = vw * n_r_fac
gemm_mk = specialize_precision(gemm_mk, precision)
for var, max_var in zip(("M", "N", "K"), (max_M, max_N, max_K)):
gemm_mk = gemm_mk.add_assertion(f"{var} <= {max_var}")

gemm_mk_starter = gemm_mk
gemm_mk = rename(gemm_mk, gemm_mk.name() + "_mk")
i_loop = gemm_mk.body()[0]
gemm_mk, (A_alloc, A_load, _, _) = auto_stage_mem(
gemm_mk, i_loop, "A", "packed_A", rc=True
)
gemm_mk, (B_alloc, B_load, _, _) = auto_stage_mem(
gemm_mk, i_loop, "B", "packed_B", rc=True
)
gemm_mk = bound_alloc(gemm_mk, A_alloc, (max_M, max_K))

gemm_mk = bound_alloc(gemm_mk, B_alloc, (max_K, max_N))
gemm_mk, _ = extract_subproc(gemm_mk, A_load, "A_pack_kernel")
gemm_mk, _ = extract_subproc(gemm_mk, B_load, "B_pack_kernel")
gemm_mk, _ = extract_subproc(gemm_mk, i_loop, "compute")
packed_A_shape = ((0, max_M // m_r), (1, max_K), (0, m_r))
gemm_mk, cursors = pack_mem(gemm_mk, i_loop, "A", packed_A_shape, "packed_A", rc=1)
gemm_mk, _ = extract_subproc(
gemm_mk, cursors.load, "A_pack_kernel"
) # TODO: Schedule packing kernel

packed_B_shape = ((1, max_N // n_r), (0, max_K), (1, n_r))
gemm_mk, cursors = pack_mem(gemm_mk, i_loop, "B", packed_B_shape, "packed_B", rc=1)
gemm_mk, _ = extract_subproc(
gemm_mk, cursors.load, "B_pack_kernel"
) # TODO: Schedule packing kernel

gemm_mk, _ = extract_subproc(gemm_mk, i_loop, "compute")
return gemm_mk_starter, gemm_mk
n_r = machine.vec_width(precision) * n_r_fac
j_loop = get_inner_loop(gemm_mk, i_loop)
Expand Down
8 changes: 4 additions & 4 deletions test/codegen/reference/sha256/avx2.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"exo_dot": "8e7a71353e80273839bc522cfc8c889cbe25d8701a8c82a3c2559b12d9e90f5f",
"exo_dsdot": "c901be9d30928e042c35aeb9dab421a34db6d593534f31b4967f9685ecf9628b",
"exo_gbmv": "cb92744337cfdf3aa97d250a18e540c2e8787380ba88bfef790a3e14aeb19f37",
"exo_gemm": "28f060df6aa28cc7259d8e5379b95f560bb728e9de322d0d752d3113ba9510ea",
"exo_gemm": "a506749d7692fab974679a3901c0bd82bd17f4a0a00b95654d0d72845e5cc144",
"exo_gemv": "80d79a5752c20874fbe3d2c94989f5f5663687d405966065a18224800c931559",
"exo_ger": "38ba46c410ea9bd1d616add516e9ab82ed811e5785848bce17191d89f81752ce",
"exo_iamax": "49c60714c479234683166e5651fbe95ed5a43ecd370a391732769588948cc842",
Expand All @@ -17,12 +17,12 @@
"exo_swap": "908f3bb8eda4065dbaef9ff0c085a9d8b55eb12d893cc53212dfc905f24d290b",
"exo_symm": "f8192433e2e64b600a573ecd2a77a404f79bb0886316c1e483530fff5bfe0cf4",
"exo_symv": "f96e0661b0221c69b43d9e476c3f5ba9a3ee41168149232b6a414145dc0d77f5",
"exo_syr": "eb284ae2c8d1424bf32fa791822231a1a1e559117a51beaaafaa4438d10c14ea",
"exo_syr": "3e6894a8a9003ede58e06c57b261a1638fd6249b8a80ed356e147862f06f39aa",
"exo_syr2": "9285dc796c9c573cbd974f419563a0c2e3d7507aba3521ee1c89e9e707913332",
"exo_syrk": "9894ba92a502df8968c0c4e1e09cb5510a9ad10d77f4e1595570a7d1a2167b4b",
"exo_tbmv": "e517f633eeaf1429c2204966a2970e7013054afd1f0bb22795075cfa5e4678db",
"exo_tbsv": "faeb1392d2af7dc9cdac9fb707bd7a3273e82c92bcd940e869fb7bc5be14f020",
"exo_trmm": "70f7aa84d76fe3be02cbbc5db13fdc9d55b9dd481ba02616f1457a15a653a074",
"exo_trmv": "8cab60255358ae985e8e34f719aca6cd60e090f58e5ef53bc3254a094bb8f9ce",
"exo_trsv": "f39e76f37f370b48df07c03254f7b4cf0a017887f3cd02708f9efc81c196e83f"
"exo_trmv": "99dc88dcc37a11f7a89d58ee2ff9c7805d5c324c443b4dbd6edcf750d3cfe440",
"exo_trsv": "8e1033febc856a710730b2297e6f647b9d61e1711d3102c497ff96110d09b5ee"
}

0 comments on commit 72553bb

Please sign in to comment.