Skip to content

Commit

Permalink
Fix signed extension in q4_1 sharktank kernel (#726)
Browse files Browse the repository at this point in the history
  • Loading branch information
rsuderman authored Jan 2, 2025
1 parent 56f3d21 commit e98e458
Showing 1 changed file with 5 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,14 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_
ins(%aexp, %b_grouped_dequant : !aexp_tensor_type, !b_grouped_tensor_type)
outs(%result_fill : !accum_tensor_type) {
^bb0(%a_element: !a_type, %b_element: !a_type, %out: !accum_type):
%bmm_mul = arith.mulf %a_element, %b_element : !a_type
{% if accum_type == a_type %}
%bmm_mul = arith.mulf %a_element, %b_element : !a_type
%bmm_accum = arith.addf %bmm_mul, %out : !a_type
{% else %}
%bmm_mul_ext = arith.extf %bmm_mul : !a_type to !accum_type
%bmm_accum = arith.addf %bmm_mul_ext, %out : !accum_type
%a_ext = arith.extf %a_element : !a_type to !accum_type
%b_ext = arith.extf %b_element : !a_type to !accum_type
%bmm_mul = arith.mulf %a_ext, %b_ext : !accum_type
%bmm_accum = arith.addf %bmm_mul, %out : !accum_type
{% endif %}
linalg.yield %bmm_accum : !accum_type
} -> !accum_tensor_type
Expand Down

0 comments on commit e98e458

Please sign in to comment.