[x265] [PATCH v2 5/9] AArch64: Optimise partialButterfly32_neon

Hari Limaye hari.limaye at arm.com
Tue Aug 27 15:10:53 UTC 2024


Optimise the Neon implementation of partialButterfly32 to process several
lines at a time (two lines per iteration of the butterfly stage, four
lines per iteration of the output stage) so as to make use of the full
width of Neon vector registers, to avoid widening to 32-bit values where
possible, and to replace the addition of a rounding constant with
rounding shift instructions.

Speedup observed relative to the existing implementation:

  Neoverse N1: 1.37x
  Neoverse V1: 1.59x
  Neoverse N2: 1.30x
  Neoverse V2: 1.28x
---
 source/common/aarch64/dct-prim.cpp | 243 ++++++++++++++++++++---------
 1 file changed, 168 insertions(+), 75 deletions(-)

diff --git a/source/common/aarch64/dct-prim.cpp b/source/common/aarch64/dct-prim.cpp
index c880bc72c..17ab7ed19 100644
--- a/source/common/aarch64/dct-prim.cpp
+++ b/source/common/aarch64/dct-prim.cpp
@@ -360,100 +360,193 @@ static inline void partialButterfly16_neon(const int16_t *src, int16_t *dst)
     }
 }
 
-
-static void partialButterfly32(const int16_t *src, int16_t *dst, int shift, int line)
+template<int shift>
+static inline void partialButterfly32_neon(const int16_t *src, int16_t *dst)
 {
-    int j, k;
-    const int add = 1 << (shift - 1);
+    const int line = 32;
 
+    int16x8_t O[line][2];
+    int32x4_t EO[line][2];
+    int32x4_t EEO[line];
+    int32x4_t EEEE[line / 2];
+    int32x4_t EEEO[line / 2];
 
-    for (j = 0; j < line; j++)
+    for (int i = 0; i < line; i += 2)
     {
-        int32x4_t VE[4], VO0, VO1, VO2, VO3;
-        int32x4_t VEE[2], VEO[2];
-        int32x4_t VEEE, VEEO;
-        int EEEE[2], EEEO[2];
-
-        int16x8x4_t inputs;
-        inputs = vld1q_s16_x4(src);
-        int16x8x4_t in_rev;
-
-        in_rev.val[1] = rev16(inputs.val[2]);
-        in_rev.val[0] = rev16(inputs.val[3]);
-
-        VE[0] = vaddl_s16(vget_low_s16(inputs.val[0]), vget_low_s16(in_rev.val[0]));
-        VE[1] = vaddl_high_s16(inputs.val[0], in_rev.val[0]);
-        VO0 = vsubl_s16(vget_low_s16(inputs.val[0]), vget_low_s16(in_rev.val[0]));
-        VO1 = vsubl_high_s16(inputs.val[0], in_rev.val[0]);
-        VE[2] = vaddl_s16(vget_low_s16(inputs.val[1]), vget_low_s16(in_rev.val[1]));
-        VE[3] = vaddl_high_s16(inputs.val[1], in_rev.val[1]);
-        VO2 = vsubl_s16(vget_low_s16(inputs.val[1]), vget_low_s16(in_rev.val[1]));
-        VO3 = vsubl_high_s16(inputs.val[1], in_rev.val[1]);
-
-        for (k = 1; k < 32; k += 2)
-        {
-            int32x4_t c0 = vmovl_s16(vld1_s16(&g_t32[k][0]));
-            int32x4_t c1 = vmovl_s16(vld1_s16(&g_t32[k][4]));
-            int32x4_t c2 = vmovl_s16(vld1_s16(&g_t32[k][8]));
-            int32x4_t c3 = vmovl_s16(vld1_s16(&g_t32[k][12]));
-            int32x4_t s = vmulq_s32(c0, VO0);
-            s = vmlaq_s32(s, c1, VO1);
-            s = vmlaq_s32(s, c2, VO2);
-            s = vmlaq_s32(s, c3, VO3);
+        int16x8x4_t in_lo = vld1q_s16_x4(src + (i + 0) * line);
+        in_lo.val[2] = rev16(in_lo.val[2]);
+        in_lo.val[3] = rev16(in_lo.val[3]);
+
+        int16x8x4_t in_hi = vld1q_s16_x4(src + (i + 1) * line);
+        in_hi.val[2] = rev16(in_hi.val[2]);
+        in_hi.val[3] = rev16(in_hi.val[3]);
+
+        int32x4_t E0[4];
+        E0[0] = vaddl_s16(vget_low_s16(in_lo.val[0]),
+                          vget_low_s16(in_lo.val[3]));
+        E0[1] = vaddl_s16(vget_high_s16(in_lo.val[0]),
+                          vget_high_s16(in_lo.val[3]));
+        E0[2] = vaddl_s16(vget_low_s16(in_lo.val[1]),
+                          vget_low_s16(in_lo.val[2]));
+        E0[3] = vaddl_s16(vget_high_s16(in_lo.val[1]),
+                          vget_high_s16(in_lo.val[2]));
+
+        int32x4_t E1[4];
+        E1[0] = vaddl_s16(vget_low_s16(in_hi.val[0]),
+                          vget_low_s16(in_hi.val[3]));
+        E1[1] = vaddl_s16(vget_high_s16(in_hi.val[0]),
+                          vget_high_s16(in_hi.val[3]));
+        E1[2] = vaddl_s16(vget_low_s16(in_hi.val[1]),
+                          vget_low_s16(in_hi.val[2]));
+        E1[3] = vaddl_s16(vget_high_s16(in_hi.val[1]),
+                          vget_high_s16(in_hi.val[2]));
+
+        O[i + 0][0] = vsubq_s16(in_lo.val[0], in_lo.val[3]);
+        O[i + 0][1] = vsubq_s16(in_lo.val[1], in_lo.val[2]);
+
+        O[i + 1][0] = vsubq_s16(in_hi.val[0], in_hi.val[3]);
+        O[i + 1][1] = vsubq_s16(in_hi.val[1], in_hi.val[2]);
+
+        int32x4_t EE0[2];
+        E0[3] = rev32(E0[3]);
+        E0[2] = rev32(E0[2]);
+        EE0[0] = vaddq_s32(E0[0], E0[3]);
+        EE0[1] = vaddq_s32(E0[1], E0[2]);
+        EO[i + 0][0] = vsubq_s32(E0[0], E0[3]);
+        EO[i + 0][1] = vsubq_s32(E0[1], E0[2]);
+
+        int32x4_t EE1[2];
+        E1[3] = rev32(E1[3]);
+        E1[2] = rev32(E1[2]);
+        EE1[0] = vaddq_s32(E1[0], E1[3]);
+        EE1[1] = vaddq_s32(E1[1], E1[2]);
+        EO[i + 1][0] = vsubq_s32(E1[0], E1[3]);
+        EO[i + 1][1] = vsubq_s32(E1[1], E1[2]);
+
+        int32x4_t EEE0;
+        EE0[1] = rev32(EE0[1]);
+        EEE0 = vaddq_s32(EE0[0], EE0[1]);
+        EEO[i + 0] = vsubq_s32(EE0[0], EE0[1]);
+
+        int32x4_t EEE1;
+        EE1[1] = rev32(EE1[1]);
+        EEE1 = vaddq_s32(EE1[0], EE1[1]);
+        EEO[i + 1] = vsubq_s32(EE1[0], EE1[1]);
 
-            dst[k * line] = (int16_t)((vaddvq_s32(s) + add) >> shift);
-
-        }
+        int32x4_t t0 = vreinterpretq_s32_s64(
+            vzip1q_s64(vreinterpretq_s64_s32(EEE0),
+                       vreinterpretq_s64_s32(EEE1)));
+        int32x4_t t1 = vrev64q_s32(vreinterpretq_s32_s64(
+            vzip2q_s64(vreinterpretq_s64_s32(EEE0),
+                       vreinterpretq_s64_s32(EEE1))));
 
-        int32x4_t rev_VE[2];
+        EEEE[i / 2] = vaddq_s32(t0, t1);
+        EEEO[i / 2] = vsubq_s32(t0, t1);
+    }
 
+    for (int k = 1; k < 32; k += 2)
+    {
+        int16_t *d = dst + k * line;
 
-        rev_VE[0] = rev32(VE[3]);
-        rev_VE[1] = rev32(VE[2]);
+        int16x8_t c0_c1 = vld1q_s16(&g_t32[k][0]);
+        int16x8_t c2_c3 = vld1q_s16(&g_t32[k][8]);
+        int16x4_t c0 = vget_low_s16(c0_c1);
+        int16x4_t c1 = vget_high_s16(c0_c1);
+        int16x4_t c2 = vget_low_s16(c2_c3);
+        int16x4_t c3 = vget_high_s16(c2_c3);
 
-        /* EE and EO */
-        for (k = 0; k < 2; k++)
+        for (int i = 0; i < line; i += 4)
         {
-            VEE[k] = vaddq_s32(VE[k], rev_VE[k]);
-            VEO[k] = vsubq_s32(VE[k], rev_VE[k]);
+            int32x4_t t[4];
+            for (int j = 0; j < 4; ++j) {
+                t[j] = vmull_s16(c0, vget_low_s16(O[i + j][0]));
+                t[j] = vmlal_s16(t[j], c1, vget_high_s16(O[i + j][0]));
+                t[j] = vmlal_s16(t[j], c2, vget_low_s16(O[i + j][1]));
+                t[j] = vmlal_s16(t[j], c3, vget_high_s16(O[i + j][1]));
+            }
+
+            int32x4_t t0123 = vpaddq_s32(vpaddq_s32(t[0], t[1]),
+                                         vpaddq_s32(t[2], t[3]));
+            int16x4_t res = vrshrn_n_s32(t0123, shift);
+            vst1_s16(d, res);
+
+            d += 4;
         }
-        for (k = 2; k < 32; k += 4)
-        {
-            int32x4_t c0 = vmovl_s16(vld1_s16(&g_t32[k][0]));
-            int32x4_t c1 = vmovl_s16(vld1_s16(&g_t32[k][4]));
-            int32x4_t s = vmulq_s32(c0, VEO[0]);
-            s = vmlaq_s32(s, c1, VEO[1]);
+    }
+
+    for (int k = 2; k < 32; k += 4)
+    {
+        int16_t *d = dst + k * line;
 
-            dst[k * line] = (int16_t)((vaddvq_s32(s) + add) >> shift);
+        int32x4_t c0 = vmovl_s16(vld1_s16(&g_t32[k][0]));
+        int32x4_t c1 = vmovl_s16(vld1_s16(&g_t32[k][4]));
 
+        for (int i = 0; i < line; i += 4)
+        {
+            int32x4_t t[4];
+            for (int j = 0; j < 4; ++j) {
+                t[j] = vmulq_s32(c0, EO[i + j][0]);
+                t[j] = vmlaq_s32(t[j], c1, EO[i + j][1]);
+            }
+
+            int32x4_t t0123 = vpaddq_s32(vpaddq_s32(t[0], t[1]),
+                                         vpaddq_s32(t[2], t[3]));
+            int16x4_t res = vrshrn_n_s32(t0123, shift);
+            vst1_s16(d, res);
+
+            d += 4;
         }
+    }
+
+    for (int k = 4; k < 32; k += 8)
+    {
+        int16_t *d = dst + k * line;
+
+        int32x4_t c = vmovl_s16(vld1_s16(&g_t32[k][0]));
 
-        int32x4_t tmp = rev32(VEE[1]);
-        VEEE = vaddq_s32(VEE[0], tmp);
-        VEEO = vsubq_s32(VEE[0], tmp);
-        for (k = 4; k < 32; k += 8)
+        for (int i = 0; i < line; i += 4)
         {
-            int32x4_t c = vmovl_s16(vld1_s16(&g_t32[k][0]));
-            int32x4_t s = vmulq_s32(c, VEEO);
+            int32x4_t t0 = vmulq_s32(c, EEO[i + 0]);
+            int32x4_t t1 = vmulq_s32(c, EEO[i + 1]);
+            int32x4_t t2 = vmulq_s32(c, EEO[i + 2]);
+            int32x4_t t3 = vmulq_s32(c, EEO[i + 3]);
 
-            dst[k * line] = (int16_t)((vaddvq_s32(s) + add) >> shift);
+            int32x4_t t = vpaddq_s32(vpaddq_s32(t0, t1), vpaddq_s32(t2, t3));
+            int16x4_t res = vrshrn_n_s32(t, shift);
+            vst1_s16(d, res);
+
+            d += 4;
         }
+    }
 
-        /* EEEE and EEEO */
-        EEEE[0] = VEEE[0] + VEEE[3];
-        EEEO[0] = VEEE[0] - VEEE[3];
-        EEEE[1] = VEEE[1] + VEEE[2];
-        EEEO[1] = VEEE[1] - VEEE[2];
+    int32x4_t c0 = vld1q_s32(t8_even[0]);
+    int32x4_t c8 = vld1q_s32(t8_even[1]);
+    int32x4_t c16 = vld1q_s32(t8_even[2]);
+    int32x4_t c24 = vld1q_s32(t8_even[3]);
 
-        dst[0] = (int16_t)((g_t32[0][0] * EEEE[0] + g_t32[0][1] * EEEE[1] + add) >> shift);
-        dst[16 * line] = (int16_t)((g_t32[16][0] * EEEE[0] + g_t32[16][1] * EEEE[1] + add) >> shift);
-        dst[8 * line] = (int16_t)((g_t32[8][0] * EEEO[0] + g_t32[8][1] * EEEO[1] + add) >> shift);
-        dst[24 * line] = (int16_t)((g_t32[24][0] * EEEO[0] + g_t32[24][1] * EEEO[1] + add) >> shift);
+    for (int i = 0; i < line; i += 4)
+    {
+        int32x4_t t0 = vpaddq_s32(EEEE[i / 2 + 0], EEEE[i / 2 + 1]);
+        int32x4_t t1 = vmulq_s32(c0, t0);
+        int16x4_t res0 = vrshrn_n_s32(t1, shift);
+        vst1_s16(dst + 0 * line, res0);
+
+        int32x4_t t2 = vmulq_s32(c8, EEEO[i / 2 + 0]);
+        int32x4_t t3 = vmulq_s32(c8, EEEO[i / 2 + 1]);
+        int16x4_t res8 = vrshrn_n_s32(vpaddq_s32(t2, t3), shift);
+        vst1_s16(dst + 8 * line, res8);
 
+        int32x4_t t4 = vmulq_s32(c16, EEEE[i / 2 + 0]);
+        int32x4_t t5 = vmulq_s32(c16, EEEE[i / 2 + 1]);
+        int16x4_t res16 = vrshrn_n_s32(vpaddq_s32(t4, t5), shift);
+        vst1_s16(dst + 16 * line, res16);
 
+        int32x4_t t6 = vmulq_s32(c24, EEEO[i / 2 + 0]);
+        int32x4_t t7 = vmulq_s32(c24, EEEO[i / 2 + 1]);
+        int16x4_t res24 = vrshrn_n_s32(vpaddq_s32(t6, t7), shift);
+        vst1_s16(dst + 24 * line, res24);
 
-        src += 32;
-        dst++;
+        dst += 4;
     }
 }
 
@@ -965,8 +1058,8 @@ void dct16_neon(const int16_t *src, int16_t *dst, intptr_t srcStride)
 
 void dct32_neon(const int16_t *src, int16_t *dst, intptr_t srcStride)
 {
-    const int shift_1st = 4 + X265_DEPTH - 8;
-    const int shift_2nd = 11;
+    const int shift_pass1 = 4 + X265_DEPTH - 8;
+    const int shift_pass2 = 11;
 
     ALIGN_VAR_32(int16_t, coef[32 * 32]);
     ALIGN_VAR_32(int16_t, block[32 * 32]);
@@ -976,8 +1069,8 @@ void dct32_neon(const int16_t *src, int16_t *dst, intptr_t srcStride)
         memcpy(&block[i * 32], &src[i * srcStride], 32 * sizeof(int16_t));
     }
 
-    partialButterfly32(block, coef, shift_1st, 32);
-    partialButterfly32(coef, dst, shift_2nd, 32);
+    partialButterfly32_neon<shift_pass1>(block, coef);
+    partialButterfly32_neon<shift_pass2>(coef, dst);
 }
 
 void idct4_neon(const int16_t *src, int16_t *dst, intptr_t dstStride)
-- 
2.42.1

-------------- next part --------------
A non-text attachment was scrubbed...
Name: v2-0005-AArch64-Optimise-partialButterfly32_neon.patch
Type: text/x-patch
Size: 12053 bytes
Desc: not available
URL: <http://mailman.videolan.org/pipermail/x265-devel/attachments/20240827/82855735/attachment-0001.bin>


More information about the x265-devel mailing list