From 7ef492e5018a918f24c7b350ca0ab93ae6d91164 Mon Sep 17 00:00:00 2001
From: Hanno Becker <beckphan@amazon.co.uk>
Date: Tue, 17 Dec 2024 05:53:41 +0000
Subject: [PATCH] AArch64: Organize shift-by-immediate Neon instructions

This commit adds the parent class

- VShiftImmediateBasic

for the instructions `ushr` and `shl`, and

- VShiftImmediateRounding

for the instruction `srshr`.

It also adds a new instruction `sshr` to the `VShiftImmediateBasic`
category, and `urshr` to `VShiftImmediateComplex`.

The motivation for those categories is that they feature in Arm's SWOGs:
Most CPUs seem to have the same instruction characteristics for all of
them, so having a parent class helps modelling.
---
 .../aarch64/aarch64_big_experimental.py       |  9 +++--
 slothy/targets/aarch64/aarch64_neon.py        | 34 ++++++++++++++-----
 .../apple_m1_firestorm_experimental.py        | 22 +++++++-----
 .../aarch64/apple_m1_icestorm_experimental.py | 15 ++++----
 slothy/targets/aarch64/cortex_a55.py          | 15 ++++----
 slothy/targets/aarch64/cortex_a72_frontend.py | 10 ++++--
 6 files changed, 70 insertions(+), 35 deletions(-)

diff --git a/slothy/targets/aarch64/aarch64_big_experimental.py b/slothy/targets/aarch64/aarch64_big_experimental.py
index 8f491c41..c53b9667 100644
--- a/slothy/targets/aarch64/aarch64_big_experimental.py
+++ b/slothy/targets/aarch64/aarch64_big_experimental.py
@@ -100,7 +100,8 @@ def get_min_max_objective(slothy):
     (vand, vadd)              : ExecutionUnit.V(),
     (vxtn)                    : ExecutionUnit.V(),
     veor3                     : ExecutionUnit.V(),
-    (vshl, vshl_d, vshli, vshrn) : ExecutionUnit.V1(),
+    (VShiftImmediateBasic,
+     vshl_d, vshli, vshrn)    : ExecutionUnit.V1(), # TODO: Should be V13?
     vusra                     : ExecutionUnit.V1(),
     AESInstruction            : ExecutionUnit.V(),
     Transpose                 : ExecutionUnit.V(),
@@ -141,7 +142,8 @@ def get_min_max_objective(slothy):
     AArch64NeonLogical         : 1,
     (vmovi)                    : 1,
     (vxtn)                     : 1,
-    (vshl, vshl_d, vshli, vshrn) : 1,
+    (VShiftImmediateBasic,
+     vshl_d, vshli, vshrn) : 1,
     (vmul)                     : 2,
     vusra                      : 1,
     (vmlal, vmull)             : 1,
@@ -180,7 +182,8 @@ def get_min_max_objective(slothy):
     (vmul)                    : 5,
     vusra                     : 4, # TODO: Add fwd path
     (vmlal, vmull)            : 4, # TODO: Add fwd path
-    (vshl, vshl_d, vshli, vshrn) : 2,
+    (VShiftImmediateBasic,
+     vshl_d, vshli, vshrn)    : 2,
     (AArch64BasicArithmetic,
      AArch64ConditionalSelect,
      AArch64ConditionalCompare,
diff --git a/slothy/targets/aarch64/aarch64_neon.py b/slothy/targets/aarch64/aarch64_neon.py
index dcdb5f16..9663c8b3 100644
--- a/slothy/targets/aarch64/aarch64_neon.py
+++ b/slothy/targets/aarch64/aarch64_neon.py
@@ -2528,12 +2528,33 @@ class vsmlal2(Vmlal): # pylint: disable=missing-docstring,invalid-name
     inputs = ["Va", "Vb"]
     in_outs=["Vd"]
 
-class vsrshr(AArch64Instruction): # pylint: disable=missing-docstring,invalid-name
+class VShiftImmediateBasic(AArch64Instruction):
+    pass
+
+class VShiftImmediateRounding(AArch64Instruction):
+    pass
+
+class vsrshr(VShiftImmediateRounding): # pylint: disable=missing-docstring,invalid-name
     pattern = "srshr <Vd>.<dt0>, <Va>.<dt1>, <imm>"
     inputs = ["Va"]
     outputs = ["Vd"]
 
-class vshl(AArch64Instruction): # pylint: disable=missing-docstring,invalid-name
+class vurshr(VShiftImmediateRounding): # pylint: disable=missing-docstring,invalid-name
+    pattern = "urshr <Vd>.<dt0>, <Va>.<dt1>, <imm>"
+    inputs = ["Va"]
+    outputs = ["Vd"]
+
+class vsshr(VShiftImmediateBasic): # pylint: disable=missing-docstring,invalid-name
+    pattern = "sshr <Vd>.<dt0>, <Va>.<dt1>, <imm>"
+    inputs = ["Va"]
+    outputs = ["Vd"]
+
+class vushr(VShiftImmediateBasic): # pylint: disable=missing-docstring,invalid-name
+    pattern = "ushr <Vd>.<dt0>, <Va>.<dt1>, <imm>"
+    inputs = ["Va"]
+    outputs = ["Vd"]
+
+class vshl(VShiftImmediateBasic): # pylint: disable=missing-docstring,invalid-name
     pattern = "shl <Vd>.<dt0>, <Va>.<dt1>, <imm>"
     inputs = ["Va"]
     outputs = ["Vd"]
@@ -2604,11 +2625,6 @@ def make(cls, src, force=False):
             raise Instruction.ParsingException("Instruction ignored")
         return AArch64Instruction.build(cls, src)
 
-class vushr(AArch64Instruction): # pylint: disable=missing-docstring,invalid-name
-    pattern = "ushr <Vd>.<dt0>, <Va>.<dt1>, <imm>"
-    inputs = ["Va"]
-    outputs = ["Vd"]
-
 class Transpose(AArch64Instruction): # pylint: disable=missing-docstring,invalid-name
     pass
 
@@ -3222,7 +3238,7 @@ def eor3_fusion_cb():
     """
     Example for a fusion call back. Allows to merge two eor instruction with
     two inputs into one eor with three inputs. Such technique can help perform
-    transformations in case of differences between uArchs. 
+    transformations in case of differences between uArchs.
     Note: This is not used in any real (crypto) example. This is merely a PoC.
     """
     def core(inst,t,log=None):
@@ -3285,7 +3301,7 @@ def eor3_splitting_cb():
     """
     Example for a splitting call back. Allows to split one eor instruction with
     three inputs into two eors with two inputs. Such technique can help perform
-    transformations in case of differences between uArchs. 
+    transformations in case of differences between uArchs.
     Note: This is not used in any real (crypto) example. This is merely a PoC.
     """
     def core(inst,t,log=None):
diff --git a/slothy/targets/aarch64/apple_m1_firestorm_experimental.py b/slothy/targets/aarch64/apple_m1_firestorm_experimental.py
index 5c327d32..8cc51cd5 100644
--- a/slothy/targets/aarch64/apple_m1_firestorm_experimental.py
+++ b/slothy/targets/aarch64/apple_m1_firestorm_experimental.py
@@ -119,8 +119,11 @@ def get_min_max_objective(slothy):
      vqrdmulh, vqrdmulh_lane,
      vqdmulh_lane,
      vmull, vmlal,
-     vsrshr, vushr, vusra, vshl,
-     vand, vbic, ASimdCompare): ExecutionUnit.V(),
+     vsrshr, vusra,
+     vand, vbic, ASimdCompare
+     VShiftImmediateBasic,
+     VShiftImmediateRounding
+     ): ExecutionUnit.V(),
     (vadd, vsub,
      trn1, trn2): ExecutionUnit.V(),
     Vins: ExecutionUnit.V(),  # guessed
@@ -183,8 +186,10 @@ def get_min_max_objective(slothy):
      vmls, vmls_lane,
      vqdmulh_lane,
      vmull, vmlal,
-     vsrshr, vushr, vusra, vshl,
-     vand, vbic, ASimdCompare): 1,
+     vusra,
+     vand, vbic, ASimdCompare,
+     VShiftImmediateRounding,
+     VShiftImmediateBasic): 1,
     (vadd, vsub,
      trn1, trn2): 1,
 
@@ -237,9 +242,10 @@ def get_min_max_objective(slothy):
      vmla, vmla_lane,
      vqdmulh_lane,
      vmull, vmlal,
-     vsrshr, vusra): 3,
-    (vshl, vushr,
-     vand, vbic, ASimdCompare): 2,
+     vusra): 3,
+    VShiftImmediateRounding: 3,
+    (vand, vbic, ASimdCompare,
+     VShiftImmediateBasic): 2,
     (vadd, vsub,
      trn1, trn2): 2,
     Vins: 2,  # or something less than 13
@@ -293,7 +299,7 @@ def get_latency(src, out_idx, dst):
     if instclass_src == umaddl_wform and instclass_dst == umaddl_wform and \
        src.args_out[0] == dst.args_in[2]:
         return (3, lambda t_src, t_dst: t_dst.program_start_var == t_src.program_start_var + 1)
-    
+
     return latency
 
 
diff --git a/slothy/targets/aarch64/apple_m1_icestorm_experimental.py b/slothy/targets/aarch64/apple_m1_icestorm_experimental.py
index dd2fa4d5..879e674c 100644
--- a/slothy/targets/aarch64/apple_m1_icestorm_experimental.py
+++ b/slothy/targets/aarch64/apple_m1_icestorm_experimental.py
@@ -97,8 +97,9 @@ def get_min_max_objective(slothy):
      vqrdmulh, vqrdmulh_lane,
      vqdmulh_lane,
      vmull, vmlal,
-     vsrshr, vushr, vusra, vshl,
-     vand, vbic, ASimdCompare): ExecutionUnit.V(),
+     vusra, vand, vbic, ASimdCompare,
+     VShiftImmediateBasic,
+     VShiftImmediateRounding): ExecutionUnit.V(),
 
     (vadd, vsub,
      trn1, trn2): ExecutionUnit.V(),
@@ -153,8 +154,9 @@ def get_min_max_objective(slothy):
      vmls, vmls_lane,
      vqdmulh_lane,
      vmull, vmlal,
-     vsrshr, vushr, vusra, vshl,
-     vand, vbic, ASimdCompare): 1,
+     vusra, vand, vbic, ASimdCompare,
+     VShiftImmediateBasic,
+     VShiftImmediateRounding): 1,
     (vadd, vsub,
      trn1, trn2): 1,
 
@@ -207,8 +209,9 @@ def get_min_max_objective(slothy):
      vmla, vmla_lane,
      vqdmulh_lane,
      vmull, vmlal,
-     vsrshr, vusra): 3,
-    (vshl, vushr,
+     vusra): 3,
+    VShiftImmediateRounding: 3,
+    (VShiftImmediateBasic,
      vand, vbic, ASimdCompare): 2,
     (vadd, vsub,
      trn1, trn2): 2,
diff --git a/slothy/targets/aarch64/cortex_a55.py b/slothy/targets/aarch64/cortex_a55.py
index 37dac2ee..7cdadb13 100644
--- a/slothy/targets/aarch64/cortex_a55.py
+++ b/slothy/targets/aarch64/cortex_a55.py
@@ -109,11 +109,13 @@ def get_min_max_objective(slothy):
         vmla, vmla_lane,
         vqrdmulh, vqrdmulh_lane,
         vqdmulh_lane,
-        vsrshr, vand, vbic,
+        vand, vbic,
         Ldr_Q,
         Str_Q,
         q_ldr1_stack, Q_Ld2_Lane_Post_Inc,
-        Vmull, Vmlal, vushr, vusra
+        Vmull, Vmlal, vusra,
+        vushr, vsshr,
+        VShiftImmediateRounding,
     ): [[ExecutionUnit.VEC0, ExecutionUnit.VEC1]],  # these instructions use both VEC0 and VEC1
 
     St4 : [[ExecutionUnit.VEC0, ExecutionUnit.VEC1, ExecutionUnit.SCALAR_LOAD,
@@ -176,7 +178,7 @@ def get_min_max_objective(slothy):
     ( vadd, vsub, vmov,
       vmul, vmul_lane, vmls, vmls_lane,
       vqrdmulh, vqrdmulh_lane, vqdmulh_lane, Vmull, Vmlal,
-      vsrshr, umov_d ) : 1,
+      umov_d ) : 1,
     (trn2, trn1, ASimdCompare): 1,
     ( Ldr_Q ) : 2,
     ( Str_Q ) : 1,
@@ -199,7 +201,8 @@ def get_min_max_objective(slothy):
      adcs_zero_r_to_zero, cmn) : 1,
     (cmp_xzr2, sub, subs_wform, asr_wform, sbcs_zero_to_zero, ngc_zero) : 1,
     (bfi) : 1,
-    (vshl, vshl, vushr) : 1,
+    VShiftImmediateRounding : 1,
+    VShiftImmediateBasic : 1,
     (vusra) : 1,
     (vand, vbic) : 1,
     (vuzp1, vuzp2) : 1,
@@ -218,7 +221,6 @@ def get_min_max_objective(slothy):
     is_dform_form_of([vadd, vsub]) : 2,
 
     (trn1, trn2, ASimdCompare): 2,
-    ( vsrshr ) : 3,
     ( vmul, vmul_lane, vmls, vmls_lane,
       vqrdmulh, vqrdmulh_lane, vqdmulh_lane, Vmull, Vmlal) : 4,
     ( Ldr_Q, Str_Q ) : 4,
@@ -244,7 +246,8 @@ def get_min_max_objective(slothy):
      sub, subs_wform, asr_wform, sbcs_zero_to_zero, cmp_xzr2,
      ngc_zero) : 1,
     (bfi) : 2,
-    (vshl, vushr) : 2,
+    VShiftImmediateRounding : 3,
+    VShiftImmediateBasic : 2,
     (vusra) : 3,
     (vand, vbic) : 1,
     (vuzp1, vuzp2) : 2,
diff --git a/slothy/targets/aarch64/cortex_a72_frontend.py b/slothy/targets/aarch64/cortex_a72_frontend.py
index 1a74a7dc..c2bd0554 100644
--- a/slothy/targets/aarch64/cortex_a72_frontend.py
+++ b/slothy/targets/aarch64/cortex_a72_frontend.py
@@ -131,7 +131,8 @@ def get_min_max_objective(slothy):
 
     (add, add_imm, add_lsl, add_lsr) : ExecutionUnit.SCALAR(),
 
-    vsrshr : [ExecutionUnit.ASIMD1],
+    (VShiftImmediateRounding,
+     VShiftImmediateBasic): [ExecutionUnit.ASIMD1],
 
     (St4, St2) : [ExecutionUnit.ASIMD0, ExecutionUnit.ASIMD1],
 
@@ -164,7 +165,8 @@ def get_min_max_objective(slothy):
       Ldr_X, Str_X )
       : 1,
 
-    vsrshr : 1,
+    (VShiftImmediateRounding,
+     VShiftImmediateBasic): 1,
 
     St2 : 4,
     St4 : 8,
@@ -195,7 +197,9 @@ def get_min_max_objective(slothy):
 
     (add, add_imm, add_lsl, add_lsr) : 2,
 
-    vsrshr : 3, # approx
+    VShiftImmediateRounding: 3, # approx
+    VShiftImmediateBasic: 3,
+
     St2 : 4,
     St4 : 8,
     Ld4 : 4