Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-chuang committed Sep 21, 2023
1 parent df95f36 commit c340af0
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def compile_ttir_to_ptx_inplace(
if num_stages is None:
num_stages = 3 if compute_capability >= 75 else 2
# TODO (jon-chuang): handle the Hopper case of num_ctas > 1
# (CTAs are Thread Block Clusters in NVIDIA speak)
# (num_ctas > 1 in Triton uses Thread Block Clusters, which are only available on Hopper)
num_ctas = 1

extra = {
Expand All @@ -176,7 +176,7 @@ def compile_ttir_to_ptx_inplace(
except RuntimeError as e:
ttir.dump()
raise ValueError("TTIR->TTGIR pass failed!") from e
if dump:
if True:
print(ttgir)
extern_libs = {}
try:
Expand Down Expand Up @@ -430,13 +430,13 @@ def prune_configs(configs, named_args):
call_proto = kernel_call.to_proto(serialized_metadata)
return jaxlib.hlo_helpers.custom_call(
call_target_name="triton_kernel_call",
result_types=out_types,
out_types=out_types,
operands=array_args,
backend_config=zlib.compress(call_proto),
operand_layouts=avals_to_layouts(ctx.avals_in),
result_layouts=avals_to_layouts(ctx.avals_out),
operand_output_aliases=dict(input_output_aliases),
).results
)


mlir.register_lowering(triton_kernel_call_p, triton_kernel_call_lowering)
Expand Down

0 comments on commit c340af0

Please sign in to comment.