Skip to content

Commit

Permalink
upgrade triton dep
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-chuang committed Sep 4, 2023
1 parent 170511d commit 84c0593
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
25 changes: 18 additions & 7 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,27 +156,38 @@ def compile_ttir_to_ptx_inplace(
compute_capability = triton_kernel_call_lib.get_compute_capability(device)
if num_stages is None:
num_stages = 3 if compute_capability >= 75 else 2
# TODO (jon-chuang): handle the Hopper case of num_ctas > 1
# (CTAs are Thread Block Clusters in NVIDIA speak)
num_ctas = 1

extra = {
'cluster_info': _triton.ClusterInfo(),
'enable_warp_specialization': False,
'enable_persistent': False,
'optimize_epilogue': False,
}
if dump:
print(ttir)
try:
ttir = tc.optimize_ttir(ttir, compute_capability)
ttgir = tc.ttir_to_ttgir(ttir, num_warps)
ttgir = tc.optimize_ttgir(ttgir, num_stages, compute_capability)
ttir = tc.optimize_ttir(ttir, arch=compute_capability)
ttgir = tc.ttir_to_ttgir(ttir, num_warps=num_warps, num_ctas=num_ctas, arch=compute_capability,)
ttgir = tc.optimize_ttgir(ttgir,
num_stages=num_stages, num_warps=num_warps, num_ctas=num_ctas, arch=compute_capability, **extra)
except RuntimeError as e:
ttir.dump()
raise ValueError("TTIR->TTGIR pass failed!") from e
if dump:
print(ttgir)
extern_libs = {}
try:
llir = tc.ttgir_to_llir(ttgir, extern_libs, compute_capability)
llir = tc.ttgir_to_llir(ttgir, extern_libs, arch=compute_capability, tma_infos=_triton.TMAInfos())
except RuntimeError as e:
ttgir.dump()
raise ValueError("TTGIR->LLIR pass failed!") from e
shared_mem_bytes = _triton.get_shared_memory_size(ttgir)
if dump:
print(llir)
ptx = tc.llir_to_ptx(llir, compute_capability)
ptx = tc.llir_to_ptx(llir, arch=compute_capability)
if dump:
print(ptx)
name = ptx_get_kernel_name(ptx)
Expand Down Expand Up @@ -419,13 +430,13 @@ def prune_configs(configs, named_args):
call_proto = kernel_call.to_proto(serialized_metadata)
return jaxlib.hlo_helpers.custom_call(
call_target_name="triton_kernel_call",
result_types=out_types,
out_types=out_types,
operands=array_args,
backend_config=zlib.compress(call_proto),
operand_layouts=avals_to_layouts(ctx.avals_in),
result_layouts=avals_to_layouts(ctx.avals_out),
operand_output_aliases=dict(input_output_aliases),
).results
)


mlir.register_lowering(triton_kernel_call_p, triton_kernel_call_lowering)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ requires-python = ">=3.9,<3.11"
dependencies = [
"absl-py>=1.4.0",
"jax @ git+https://github.com/google/jax@a0c1265bbae2c3ec644d6181f23264b4794e9eac",
"triton-nightly @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/07c94329-d4c3-4ad4-9e6b-f904a60032ec/pypi/download/triton-nightly/2.1.dev20230714011643/triton_nightly-2.1.0.dev20230714011643-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
"triton-nightly @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/07c94329-d4c3-4ad4-9e6b-f904a60032ec/pypi/download/triton-nightly/2.1.dev20230822000928/triton_nightly-2.1.0.dev20230822000928-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=39f23718220984746c7fc5831d65c2c990eaf2d755a1c16bcba24946515ef0f6"
]

[project.optional-dependencies]
Expand Down

0 comments on commit 84c0593

Please sign in to comment.