diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index ff7019ee0..d7b2da6f6 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -372,6 +372,7 @@ def main(): tuner_ctx, args.limit, args.num_subgroups, + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse, ) for candidate_num, spec in enumerate(specs): spec_dir = Path(args.output) diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 50a36d02f..c20325249 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -169,15 +169,14 @@ def generate_tile_and_fuse_constraints( m_tiles, n_tiles, k_tiles, subgroup_m_tiles, subgroup_n_tiles = tile_sizes intrinsic_mn, intrinsic_k = intrinsic_size wg_x, wg_y, wg_z = workgroup_size - wg_threads = z3.Int("wg_threads") - constraints = [wg_x == wg_threads, wg_y == 1, wg_z == 1] + wg_threads = wg_x + constraints = [wg_y == 1, wg_z == 1] constraints += [subgroup_size == 64, wg_threads <= 1024] constraints += [ get_mfma_intrinsic_constraints( problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k, mma_intrinsics ) ] - subgroup_k_count = 1 constraints += [ m_tiles[-1] >= intrinsic_mn, @@ -192,9 +191,9 @@ def generate_tile_and_fuse_constraints( constraints += [m_shape % m == 0 for m, m_shape in zip(m_tiles, M)] constraints += [n_shape % n == 0 for n, n_shape in zip(n_tiles, N)] constraints += [k_shape % k == 0 for k, k_shape in zip(k_tiles[:-1], K[:-1])] - constraints += [m >= 0 for m in m_tiles] - constraints += [n >= 0 for n in n_tiles] - constraints += [k >= 0 for k in k_tiles] + constraints += [m >= 1 for m in m_tiles] + constraints += [n >= 1 for n in n_tiles] + constraints += [k >= 1 for k in k_tiles] constraints += [K[-1] % (k_tiles[-1] * intrinsic_k) == 0] constraints += [m <= m_shape for m, m_shape in zip(m_tiles, M)] constraints += [n <= n_shape for n, n_shape in zip(n_tiles, N)] @@ -203,29 +202,27 @@ def generate_tile_and_fuse_constraints( for x in (subgroup_m_count, subgroup_n_count): constraints += [x >= 1, x <= 32] - subgroup_m_tile_count = z3.Int("sg_m_tcnt") - subgroup_n_tile_count = z3.Int("sg_n_tcnt") - subgroup_k_tile_count = z3.Int("sg_k_tcnt") - for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count): - constraints += [x >= 1, x <= 32] - constraints += [math.prod(subgroup_m_tiles) == subgroup_m_tile_count] - constraints += [math.prod(subgroup_n_tiles) == subgroup_n_tile_count] constraints += [ - m % m_subgroup == 0 for m, m_subgroup in zip(m_tiles, subgroup_m_tiles) + m % m_subgroup == 0 + for m, m_subgroup in zip(m_tiles[:-1], subgroup_m_tiles[:-1]) ] constraints += [ - n % n_subgroup == 0 for n, n_subgroup in zip(n_tiles, subgroup_n_tiles) + n % n_subgroup == 0 + for n, n_subgroup in zip(n_tiles[:-1], subgroup_n_tiles[:-1]) ] - constraints += [m_subgroup > 0 for m_subgroup in subgroup_m_tiles] - constraints += [n_subgroup > 0 for n_subgroup in subgroup_n_tiles] + constraints += [m_tiles[-1] % (subgroup_m_tiles[-1] * intrinsic_mn) == 0] + constraints += [n_tiles[-1] % (subgroup_n_tiles[-1] * intrinsic_mn) == 0] + constraints += [m_subgroup >= 1 for m_subgroup in subgroup_m_tiles] + constraints += [n_subgroup >= 1 for n_subgroup in subgroup_n_tiles] constraints += [ - math.prod(m_tiles) == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn + math.prod(m_tiles) + == math.prod(subgroup_m_tiles) * subgroup_m_count * intrinsic_mn ] constraints += [ - math.prod(n_tiles) == subgroup_n_count * subgroup_n_tile_count * intrinsic_mn + math.prod(n_tiles) + == math.prod(subgroup_n_tiles) * subgroup_n_count * intrinsic_mn ] - constraints += [math.prod(k_tiles) == subgroup_k_count * subgroup_k_tile_count] subgroups = subgroup_m_count * subgroup_n_count if num_subgroups > 0: constraints += [subgroups == num_subgroups] diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py index d31a76e90..1116adac3 100644 --- a/tuner/tuner/dispatch_constraints_test.py +++ b/tuner/tuner/dispatch_constraints_test.py @@ -124,20 +124,20 @@ def test_generate_tile_and_fuse_constraints_valid_input( tuner_ctx: common.TunerContext, ) -> None: matmul_size = common.ContractionSizes( - M=[4, 32], - N=[6, 64], - K=[8, 128], - B=[2, 16], + M=[32], + N=[64], + K=[128], + B=[2], ) contraction_dims = common.ContractionDimensions( - m=[1, 5], - n=[2, 6], - k=[3, 7], - batch=[0, 4], + m=[1], + n=[2], + k=[3], + batch=[0], ) - lhs_type = common.ShapedType([2, 4, 8, 16, 32, 128], tuner_ctx.type.f16) - rhs_type = common.ShapedType([2, 6, 8, 16, 64, 128], tuner_ctx.type.f16) - res_type = common.ShapedType([2, 4, 6, 16, 32, 64], tuner_ctx.type.f32) + lhs_type = common.ShapedType([2, 32, 128], tuner_ctx.type.f16) + rhs_type = common.ShapedType([2, 64, 128], tuner_ctx.type.f16) + res_type = common.ShapedType([2, 32, 64], tuner_ctx.type.f32) problem_size = common.ProblemSize( matmul_size, lhs_type, @@ -148,13 +148,13 @@ def test_generate_tile_and_fuse_constraints_valid_input( ) # Define input parameters as z3 Ints m, n, k = ( - [z3.Int("m0"), z3.Int("m1")], - [z3.Int("n0"), z3.Int("n1")], - [z3.Int("k0"), z3.Int("k1")], + [z3.Int("m0")], + [z3.Int("n0")], + [z3.Int("k0")], ) subgroup_m, subgroup_n = ( - [z3.Int("subgroup_m0"), z3.Int("subgroup_m1")], - [z3.Int("subgroup_n0"), z3.Int("subgroup_n1")], + [z3.Int("subgroup_m0")], + [z3.Int("subgroup_n0")], ) subgroup_size = z3.Int("subgroup_size") intrinsic_mn = z3.Int("intrinsic_mn") @@ -198,20 +198,20 @@ def test_generate_tile_and_fuse_constraints_invalid_input( ) -> None: # Define input parameters that should lead to unsatisfiable constraints matmul_size = common.ContractionSizes( - M=[4, 32], - N=[6, 64], - K=[8, 128], - B=[2, 16], + M=[32], + N=[64], + K=[128], + B=[2], ) contraction_dims = common.ContractionDimensions( - m=[1, 5], - n=[2, 6], - k=[3, 7], - batch=[0, 4], + m=[1], + n=[2], + k=[3], + batch=[0], ) - lhs_type = common.ShapedType([2, 4, 8, 16, 32, 128], tuner_ctx.type.f16) - rhs_type = common.ShapedType([2, 6, 8, 16, 64, 128], tuner_ctx.type.f16) - res_type = common.ShapedType([2, 4, 6, 16, 32, 64], tuner_ctx.type.f32) + lhs_type = common.ShapedType([2, 32, 128], tuner_ctx.type.f16) + rhs_type = common.ShapedType([2, 64, 128], tuner_ctx.type.f16) + res_type = common.ShapedType([2, 32, 64], tuner_ctx.type.f32) problem_size = common.ProblemSize( matmul_size, lhs_type, @@ -222,13 +222,13 @@ def test_generate_tile_and_fuse_constraints_invalid_input( ) # Define input parameters as z3 Ints m, n, k = ( - [z3.Int("m0"), z3.Int("m1")], - [z3.Int("n0"), z3.Int("n1")], - [z3.Int("k0"), z3.Int("k1")], + [z3.Int("m0")], + [z3.Int("n0")], + [z3.Int("k0")], ) subgroup_m, subgroup_n = ( - [z3.Int("subgroup_m0"), z3.Int("subgroup_m1")], - [z3.Int("subgroup_n0"), z3.Int("subgroup_n1")], + [z3.Int("subgroup_m0")], + [z3.Int("subgroup_n0")], ) subgroup_size = z3.Int("subgroup_size") intrinsic_mn = z3.Int("intrinsic_mn")