Commit c834be3

Integrate Triton up to [68aa962e67baa191cec5aac173255abdba80db1a](htt…

Google-ML-Automation committed Oct 21, 2024
1 parent c7b8cd5 commit c834be3

Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions jax_triton/triton_lib.py
@@ -371,10 +371,11 @@ def get_or_create_triton_kernel(
   # `JITFunction._get_config` to get the specialization_attr.
   mock_torch_tensor = types.SimpleNamespace(data_ptr=lambda: 16)
   args_for_specialization_attr = [mock_torch_tensor] * len(arg_dtypes)
+  backend = backend_init_func(device, compute_capability)
   for i, _, v in scalar_args:
     args_for_specialization_attr[i] = v
-  specialization_attr = fn._get_config(*args_for_specialization_attr)  # pylint: disable=protected-access
+  specialization_attr = backend.get_attrs_descriptor(fn.params[:len(args_for_specialization_attr)], args_for_specialization_attr)  # pylint: disable=protected-access
 
   constants = {k: v for k, v in metaparams.items()}
   constants.update({k: None for _, k, v in scalar_args if v is None})
   constants.update({fn.arg_names[i]: 1 for i in specialization_attr.equal_to_1})
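Context for the hunk above: Triton moved argument specialization from the private `JITFunction._get_config` onto the backend's `get_attrs_descriptor`, which is why `backend` is now constructed before specialization (its later initialization is deleted in a hunk below). The `data_ptr=lambda: 16` mock works because Triton probes `data_ptr()` to judge pointer alignment. A minimal runnable sketch of that idea; the `divisible_by_16` checker here is an illustrative assumption, not Triton's actual code:

```python
import types

# Stand-in for a torch tensor: Triton only calls data_ptr() during
# specialization, and an address of 16 looks 16-byte aligned.
mock_torch_tensor = types.SimpleNamespace(data_ptr=lambda: 16)

def divisible_by_16(arg) -> bool:
  # Illustrative version of the alignment check: pointers report their
  # address, integer scalars are checked directly.
  if hasattr(arg, "data_ptr"):
    return arg.data_ptr() % 16 == 0
  return isinstance(arg, int) and arg % 16 == 0

assert divisible_by_16(mock_torch_tensor)   # every pointer looks aligned
assert divisible_by_16(32) and not divisible_by_16(7)
```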
@@ -383,7 +384,7 @@ def get_or_create_triton_kernel(
   cache_key = (
       fn,
       tuple(signature.items()),
-      tuple(vars(specialization_attr).values()),
+      tuple(specialization_attr.arg_properties),
       tuple(constants.items()),
       num_warps,
       num_stages,
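A likely reason for the cache-key change in this hunk: `hash(cache_key)` is computed a few lines below, so every component must be hashable, and the new descriptor's fields appear to be lists (see the test diff further down), which would make `tuple(vars(...).values())` unhashable. A sketch under that assumption, with `arg_properties` modeled as a dict keyed by property name:

```python
# Assumed shape of the new descriptor's state; names are illustrative.
attrs_fields = {"divisibility_16": [1, 3, 9], "equal_to_1": [8, 10]}

try:
  hash(tuple(attrs_fields.values()))  # a tuple that contains lists
except TypeError as err:
  print(err)  # unhashable type: 'list'

# Iterating the mapping itself yields only its string keys, which hash fine.
print(hash(tuple(attrs_fields)))
```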
@@ -403,7 +404,6 @@ def get_or_create_triton_kernel(
       "enable_fp_fusion": enable_fp_fusion,
   }
 
-  backend = backend_init_func(device, compute_capability)
   options = backend.parse_options(opts)
 
   kernel_hash = abs(hash(cache_key))
@@ -643,7 +643,7 @@ def prune_configs(configs, named_args, **kwargs):
       kernel_params.append(
           triton_kernel_call_lib.create_array_parameter(
               zeroed_params_with_sizes.get(i, 0),
-              16 if (i in specialization_attr.divisible_by_16) else 0,
+              16 if (i in specialization_attr.divisibility_16) else 0,
           )
       )
     elif i not in specialization_attr.equal_to_1:
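The final hunk in this file is a pure rename: the set of 16-divisible argument indices is now exposed as `divisibility_16` instead of `divisible_by_16`. Code that has to straddle both Triton versions could read the attribute defensively; a hedged sketch (the helper is hypothetical, jax-triton itself simply adopts the new name):

```python
def divisibility_16_indices(attrs):
  """Return 16-divisible argument indices under either attribute name.

  `divisibility_16` is the newer Triton spelling, `divisible_by_16` the
  older one; this fallback is an illustrative pattern only.
  """
  for name in ("divisibility_16", "divisible_by_16"):
    if hasattr(attrs, name):
      return tuple(getattr(attrs, name))
  return ()
```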
4 changes: 2 additions & 2 deletions tests/triton_call_test.py
@@ -564,10 +564,10 @@ def test_specialization(self):
     # Pointers are assumed to divide by 16, as do `M`, `N`, `stride_{bk,cm}`.
     # However, we've marked `a_ptr`, `M`, `stride_bk`, and `c_ptr` as "do not
     # specialize", leaving `b_ptr`, `N`, and `stride_cm`.
-    self.assertEqual(specialization.attrs.divisible_by_16, (1, 3, 9))
+    self.assertEqual(specialization.attrs.divisibility_16, [1, 3, 9])
     # `stride_{ak,bn,cn}` equal 1, but we've marked `stride_ak` as "do not
     # specialize" leaving `stride_{bn,cn}`.
-    self.assertEqual(specialization.attrs.equal_to_1, (8, 10))
+    self.assertEqual(specialization.attrs.equal_to_1, [8, 10])
 
 
 if __name__ == "__main__":
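The expected values switch from tuples to lists because the backend-built descriptor now reports its index sets as lists. A stand-in dataclass (not the real Triton class) shows why the old tuple expectations would fail:

```python
from dataclasses import dataclass, field

@dataclass
class FakeAttrsDescriptor:  # illustrative stand-in only
  divisibility_16: list = field(default_factory=list)
  equal_to_1: list = field(default_factory=list)

attrs = FakeAttrsDescriptor(divisibility_16=[1, 3, 9], equal_to_1=[8, 10])
assert attrs.divisibility_16 == [1, 3, 9]  # list expectation passes
assert attrs.divisibility_16 != (1, 3, 9)  # a list never equals a tuple
```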