Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade triton dependency to dev20230822 nightly #229

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,27 +156,38 @@ def compile_ttir_to_ptx_inplace(
compute_capability = triton_kernel_call_lib.get_compute_capability(device)
if num_stages is None:
num_stages = 3 if compute_capability >= 75 else 2
# TODO (jon-chuang): handle the Hopper case of num_ctas > 1
jon-chuang marked this conversation as resolved.
Show resolved Hide resolved
# (CTAs are Thread Block Clusters in NVIDIA speak)
num_ctas = 1

extra = {
'cluster_info': _triton.ClusterInfo(),
'enable_warp_specialization': False,
'enable_persistent': False,
'optimize_epilogue': False,
}
if dump:
print(ttir)
try:
ttir = tc.optimize_ttir(ttir, compute_capability)
ttgir = tc.ttir_to_ttgir(ttir, num_warps)
ttgir = tc.optimize_ttgir(ttgir, num_stages, compute_capability)
ttir = tc.optimize_ttir(ttir, arch=compute_capability)
ttgir = tc.ttir_to_ttgir(ttir, num_warps=num_warps, num_ctas=num_ctas, arch=compute_capability,)
ttgir = tc.optimize_ttgir(ttgir,
num_stages=num_stages, num_warps=num_warps, num_ctas=num_ctas, arch=compute_capability, **extra)
except RuntimeError as e:
ttir.dump()
raise ValueError("TTIR->TTGIR pass failed!") from e
if dump:
print(ttgir)
extern_libs = {}
try:
llir = tc.ttgir_to_llir(ttgir, extern_libs, compute_capability)
llir = tc.ttgir_to_llir(ttgir, extern_libs, arch=compute_capability, tma_infos=_triton.TMAInfos())
except RuntimeError as e:
ttgir.dump()
raise ValueError("TTGIR->LLIR pass failed!") from e
shared_mem_bytes = _triton.get_shared_memory_size(ttgir)
if dump:
print(llir)
ptx = tc.llir_to_ptx(llir, compute_capability)
ptx = tc.llir_to_ptx(llir, arch=compute_capability)
if dump:
print(ptx)
name = ptx_get_kernel_name(ptx)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ requires-python = ">=3.9,<3.11"
dependencies = [
"absl-py>=1.4.0",
"jax @ git+https://github.com/google/jax@a0c1265bbae2c3ec644d6181f23264b4794e9eac",
"triton-nightly @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/07c94329-d4c3-4ad4-9e6b-f904a60032ec/pypi/download/triton-nightly/2.1.dev20230714011643/triton_nightly-2.1.0.dev20230714011643-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
"triton-nightly @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/07c94329-d4c3-4ad4-9e6b-f904a60032ec/pypi/download/triton-nightly/2.1.dev20230822000928/triton_nightly-2.1.0.dev20230822000928-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=39f23718220984746c7fc5831d65c2c990eaf2d755a1c16bcba24946515ef0f6"
]

[project.optional-dependencies]
Expand Down