diff --git a/src/common/rc_wrappers.py b/src/common/rc_wrappers.py index a7a8f31..2bbcbd2 100644 --- a/src/common/rc_wrappers.py +++ b/src/common/rc_wrappers.py @@ -84,15 +84,15 @@ def divide_loop_(proc, loop_cursor, div_const, tail="guard", rc=False): @dataclass class stage_mem_cursors: alloc: AllocCursor - load_stage: Cursor + load: Cursor block: BlockCursor - store_stage: Cursor + store: Cursor def __iter__(self): yield self.alloc - yield self.load_stage + yield self.load yield self.block - yield self.store_stage + yield self.store def stage_mem_(proc, block, buff, new_buff_name, accum=False, rc=False): @@ -105,12 +105,12 @@ def stage_mem_(proc, block, buff, new_buff_name, accum=False, rc=False): block_first = proc.forward(block_first) block_last = proc.forward(block_last) alloc = block_first.prev().prev() - load_stage = block_first.prev() + load = block_first.prev() block = block_first.as_block().expand(0, len(block) - 1) - store_stage = block_last.next() + store = block_last.next() if not rc: return proc - return proc, stage_mem_cursors(alloc, load_stage, block, store_stage) + return proc, stage_mem_cursors(alloc, load, block, store) @dataclass diff --git a/src/common/stdlib.py b/src/common/stdlib.py index 197e047..69543bc 100644 --- a/src/common/stdlib.py +++ b/src/common/stdlib.py @@ -1348,3 +1348,69 @@ def squash_buffers(proc, buffers): proc = squash(proc, buffers[1:], buffer) max_d = max(depths) return proc + + +@dataclass +class pack_mem_cursors: + alloc: AllocCursor + load: Cursor + block: BlockCursor + store: Cursor + + def __iter__(self): + yield self.alloc + yield self.load + yield self.block + yield self.store + + +def pack_mem(proc, block, buffer, shape, name=None, rc=False): + proc, (alloc, load, block, store) = auto_stage_mem( + proc, block, buffer, name, rc=True + ) + bounds = [1] * len(alloc.shape()) + + loop_nest = [[] for _ in alloc.shape()] + + def add_loop_nest(loop): + if is_invalid(proc, loop): + return + for idx, _ in enumerate(alloc.shape()): + loop_nest[idx].append(loop) + loop = loop.body()[0] + + add_loop_nest(load) + add_loop_nest(store) + + divisions = {} + for dst_dim, (src_dim, size) in enumerate(shape): + divisions.setdefault(src_dim, []).append((dst_dim, size)) + bounds[src_dim] *= size + proc = bound_alloc(proc, alloc, bounds) + perm = [] + for src_dim, _ in reversed(list(enumerate(alloc.shape()))): + for dst_dim, size in reversed(divisions[src_dim][1:]): + proc = divide_dim(proc, alloc, src_dim, size) + # proc = apply(divide_loop_)(proc, loop_nest[src_dim], size, tail="cut") # TODO: This should be enabled but it slows down compilation + perm.append(dst_dim) + perm.append(divisions[src_dim][0][0]) + + perm = perm[::-1] + final_perm = [0] * len(perm) + for i, val in enumerate(perm): + final_perm[val] = i + proc = rearrange_dim(proc, alloc, final_perm) + proc = simplify(proc) + + if not rc: + return proc + + alloc = proc.forward(alloc) + block = proc.forward(block) + + if not is_invalid(proc, load): + load = proc.forward(load) + diff = get_index_in_body(proc, block[0]) - get_index_in_body(proc, load) - 1 + load = load.as_block().expand(0, diff) + + return proc, pack_mem_cursors(alloc, load, block, store) diff --git a/src/level3/gemm.py b/src/level3/gemm.py index 9947d60..d04a675 100644 --- a/src/level3/gemm.py +++ b/src/level3/gemm.py @@ -78,25 +78,29 @@ def rewrite(gemm_uk): def schedule_macro( gemm_mk, precision, machine, max_M, max_N, max_K, m_r, n_r_fac, do_br=False ): + vw = machine.vec_width(precision) + n_r = vw * n_r_fac gemm_mk = specialize_precision(gemm_mk, precision) for var, max_var in zip(("M", "N", "K"), (max_M, max_N, max_K)): gemm_mk = gemm_mk.add_assertion(f"{var} <= {max_var}") + gemm_mk_starter = gemm_mk gemm_mk = rename(gemm_mk, gemm_mk.name() + "_mk") i_loop = gemm_mk.body()[0] - gemm_mk, (A_alloc, A_load, _, _) = auto_stage_mem( - gemm_mk, i_loop, "A", "packed_A", rc=True - ) - gemm_mk, (B_alloc, B_load, _, _) = auto_stage_mem( - gemm_mk, i_loop, "B", "packed_B", rc=True - ) - gemm_mk = bound_alloc(gemm_mk, A_alloc, (max_M, max_K)) - gemm_mk = bound_alloc(gemm_mk, B_alloc, (max_K, max_N)) - gemm_mk, _ = extract_subproc(gemm_mk, A_load, "A_pack_kernel") - gemm_mk, _ = extract_subproc(gemm_mk, B_load, "B_pack_kernel") - gemm_mk, _ = extract_subproc(gemm_mk, i_loop, "compute") + packed_A_shape = ((0, max_M // m_r), (1, max_K), (0, m_r)) + gemm_mk, cursors = pack_mem(gemm_mk, i_loop, "A", packed_A_shape, "packed_A", rc=1) + gemm_mk, _ = extract_subproc( + gemm_mk, cursors.load, "A_pack_kernel" + ) # TODO: Schedule packing kernel + packed_B_shape = ((1, max_N // n_r), (0, max_K), (1, n_r)) + gemm_mk, cursors = pack_mem(gemm_mk, i_loop, "B", packed_B_shape, "packed_B", rc=1) + gemm_mk, _ = extract_subproc( + gemm_mk, cursors.load, "B_pack_kernel" + ) # TODO: Schedule packing kernel + + gemm_mk, _ = extract_subproc(gemm_mk, i_loop, "compute") return gemm_mk_starter, gemm_mk n_r = machine.vec_width(precision) * n_r_fac j_loop = get_inner_loop(gemm_mk, i_loop) diff --git a/test/codegen/reference/sha256/avx2.json b/test/codegen/reference/sha256/avx2.json index f07734a..b8ce6f8 100644 --- a/test/codegen/reference/sha256/avx2.json +++ b/test/codegen/reference/sha256/avx2.json @@ -5,7 +5,7 @@ "exo_dot": "8e7a71353e80273839bc522cfc8c889cbe25d8701a8c82a3c2559b12d9e90f5f", "exo_dsdot": "c901be9d30928e042c35aeb9dab421a34db6d593534f31b4967f9685ecf9628b", "exo_gbmv": "cb92744337cfdf3aa97d250a18e540c2e8787380ba88bfef790a3e14aeb19f37", - "exo_gemm": "28f060df6aa28cc7259d8e5379b95f560bb728e9de322d0d752d3113ba9510ea", + "exo_gemm": "a506749d7692fab974679a3901c0bd82bd17f4a0a00b95654d0d72845e5cc144", "exo_gemv": "80d79a5752c20874fbe3d2c94989f5f5663687d405966065a18224800c931559", "exo_ger": "38ba46c410ea9bd1d616add516e9ab82ed811e5785848bce17191d89f81752ce", "exo_iamax": "49c60714c479234683166e5651fbe95ed5a43ecd370a391732769588948cc842", @@ -17,12 +17,12 @@ "exo_swap": "908f3bb8eda4065dbaef9ff0c085a9d8b55eb12d893cc53212dfc905f24d290b", "exo_symm": "f8192433e2e64b600a573ecd2a77a404f79bb0886316c1e483530fff5bfe0cf4", "exo_symv": "f96e0661b0221c69b43d9e476c3f5ba9a3ee41168149232b6a414145dc0d77f5", - "exo_syr": "eb284ae2c8d1424bf32fa791822231a1a1e559117a51beaaafaa4438d10c14ea", + "exo_syr": "3e6894a8a9003ede58e06c57b261a1638fd6249b8a80ed356e147862f06f39aa", "exo_syr2": "9285dc796c9c573cbd974f419563a0c2e3d7507aba3521ee1c89e9e707913332", "exo_syrk": "9894ba92a502df8968c0c4e1e09cb5510a9ad10d77f4e1595570a7d1a2167b4b", "exo_tbmv": "e517f633eeaf1429c2204966a2970e7013054afd1f0bb22795075cfa5e4678db", "exo_tbsv": "faeb1392d2af7dc9cdac9fb707bd7a3273e82c92bcd940e869fb7bc5be14f020", "exo_trmm": "70f7aa84d76fe3be02cbbc5db13fdc9d55b9dd481ba02616f1457a15a653a074", - "exo_trmv": "8cab60255358ae985e8e34f719aca6cd60e090f58e5ef53bc3254a094bb8f9ce", - "exo_trsv": "f39e76f37f370b48df07c03254f7b4cf0a017887f3cd02708f9efc81c196e83f" + "exo_trmv": "99dc88dcc37a11f7a89d58ee2ff9c7805d5c324c443b4dbd6edcf750d3cfe440", + "exo_trsv": "8e1033febc856a710730b2297e6f647b9d61e1711d3102c497ff96110d09b5ee" }