[x265] [PATCH v2 3/9] AArch64: Optimise partialButterfly8_neon

Hari Limaye hari.limaye at arm.com
Tue Aug 27 15:10:40 UTC 2024


Optimise the Neon implementation of partialButterfly8 so that it uses Neon
operations rather than scalar code, processing four lines at a time to make
use of the full width of the Neon vector registers. Avoid widening to 32-bit
values where possible, and replace the addition of a rounding constant with
rounding shift instructions.
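
As an editorial illustration of the rounding change (a minimal standalone
sketch, not part of the patch): the scalar code adds a bias of
1 << (shift - 1) before shifting right, whereas Neon can fold that bias into
a single rounding-shift-right-narrow instruction. The shift amount must be a
compile-time constant for the intrinsic, which is also why the patch makes
the shift a template parameter of partialButterfly8_neon.

    #include <arm_neon.h>

    // Hypothetical helper, not part of the patch: round four 32-bit sums
    // and narrow them to 16-bit coefficients in a single instruction.
    template<int shift>
    static inline int16x4_t round_shift_narrow(int32x4_t sum)
    {
        // Scalar equivalent: (int16_t)((sum + (1 << (shift - 1))) >> shift)
        return vrshrn_n_s32(sum, shift);
    }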

Relative performance observed compared to the existing implementation:

  Neoverse N1: 2.69x
  Neoverse V1: 4.57x
  Neoverse N2: 2.87x
  Neoverse V2: 4.87x

Co-authored-by: Jonathan Wright <jonathan.wright at arm.com>
---
 source/common/aarch64/dct-prim.cpp | 136 +++++++++++++++++++++--------
 1 file changed, 99 insertions(+), 37 deletions(-)

diff --git a/source/common/aarch64/dct-prim.cpp b/source/common/aarch64/dct-prim.cpp
index 522210689..e07872157 100644
--- a/source/common/aarch64/dct-prim.cpp
+++ b/source/common/aarch64/dct-prim.cpp
@@ -18,6 +18,15 @@ namespace
 {
 using namespace X265_NS;
 
+// First two columns of the 4x4 dct transform matrix, duplicated to 4x4 to allow
+// processing two lines at once.
+const int32_t t8_even[4][4] =
+{
+    { 64,  64, 64,  64 },
+    { 83,  36, 83,  36 },
+    { 64, -64, 64, -64 },
+    { 36, -83, 36, -83 },
+};
 
 static int16x8_t rev16(const int16x8_t a)
 {
@@ -398,43 +407,96 @@ static void partialButterfly32(const int16_t *src, int16_t *dst, int shift, int
     }
 }
 
-static void partialButterfly8(const int16_t *src, int16_t *dst, int shift, int line)
+template<int shift>
+static inline void partialButterfly8_neon(const int16_t *src, int16_t *dst)
 {
-    int j, k;
-    int E[4], O[4];
-    int EE[2], EO[2];
-    int add = 1 << (shift - 1);
+    const int line = 8;
 
-    for (j = 0; j < line; j++)
+    int16x4_t O[line];
+    int32x4_t EE[line / 2];
+    int32x4_t EO[line / 2];
+
+    for (int i = 0; i < line; i += 2)
     {
-        /* E and O*/
-        for (k = 0; k < 4; k++)
-        {
-            E[k] = src[k] + src[7 - k];
-            O[k] = src[k] - src[7 - k];
-        }
+        int16x4_t s0_lo = vld1_s16(src + i * line);
+        int16x4_t s0_hi = vrev64_s16(vld1_s16(src + i * line + 4));
 
-        /* EE and EO */
-        EE[0] = E[0] + E[3];
-        EO[0] = E[0] - E[3];
-        EE[1] = E[1] + E[2];
-        EO[1] = E[1] - E[2];
-
-        dst[0] = (int16_t)((g_t8[0][0] * EE[0] + g_t8[0][1] * EE[1] + add) >> shift);
-        dst[4 * line] = (int16_t)((g_t8[4][0] * EE[0] + g_t8[4][1] * EE[1] + add) >> shift);
-        dst[2 * line] = (int16_t)((g_t8[2][0] * EO[0] + g_t8[2][1] * EO[1] + add) >> shift);
-        dst[6 * line] = (int16_t)((g_t8[6][0] * EO[0] + g_t8[6][1] * EO[1] + add) >> shift);
-
-        dst[line] = (int16_t)((g_t8[1][0] * O[0] + g_t8[1][1] * O[1] + g_t8[1][2] * O[2] + g_t8[1][3] * O[3] + add) >> shift);
-        dst[3 * line] = (int16_t)((g_t8[3][0] * O[0] + g_t8[3][1] * O[1] + g_t8[3][2] * O[2] + g_t8[3][3] * O[3] + add) >>
-                                  shift);
-        dst[5 * line] = (int16_t)((g_t8[5][0] * O[0] + g_t8[5][1] * O[1] + g_t8[5][2] * O[2] + g_t8[5][3] * O[3] + add) >>
-                                  shift);
-        dst[7 * line] = (int16_t)((g_t8[7][0] * O[0] + g_t8[7][1] * O[1] + g_t8[7][2] * O[2] + g_t8[7][3] * O[3] + add) >>
-                                  shift);
-
-        src += 8;
-        dst++;
+        int16x4_t s1_lo = vld1_s16(src + (i + 1) * line);
+        int16x4_t s1_hi = vrev64_s16(vld1_s16(src + (i + 1) * line + 4));
+
+        int32x4_t E0 = vaddl_s16(s0_lo, s0_hi);
+        int32x4_t E1 = vaddl_s16(s1_lo, s1_hi);
+
+        O[i + 0] = vsub_s16(s0_lo, s0_hi);
+        O[i + 1] = vsub_s16(s1_lo, s1_hi);
+
+        int32x4_t t0 = vreinterpretq_s32_s64(
+            vzip1q_s64(vreinterpretq_s64_s32(E0), vreinterpretq_s64_s32(E1)));
+        int32x4_t t1 = vrev64q_s32(vreinterpretq_s32_s64(
+            vzip2q_s64(vreinterpretq_s64_s32(E0), vreinterpretq_s64_s32(E1))));
+
+        EE[i / 2] = vaddq_s32(t0, t1);
+        EO[i / 2] = vsubq_s32(t0, t1);
+    }
+
+    int16_t *d = dst;
+
+    int32x4_t c0 = vld1q_s32(t8_even[0]);
+    int32x4_t c2 = vld1q_s32(t8_even[1]);
+    int32x4_t c4 = vld1q_s32(t8_even[2]);
+    int32x4_t c6 = vld1q_s32(t8_even[3]);
+    int16x4_t c1 = vld1_s16(g_t8[1]);
+    int16x4_t c3 = vld1_s16(g_t8[3]);
+    int16x4_t c5 = vld1_s16(g_t8[5]);
+    int16x4_t c7 = vld1_s16(g_t8[7]);
+
+    for (int j = 0; j < line; j += 4)
+    {
+        // O
+        int32x4_t t01 = vpaddq_s32(vmull_s16(c1, O[j + 0]),
+                                   vmull_s16(c1, O[j + 1]));
+        int32x4_t t23 = vpaddq_s32(vmull_s16(c1, O[j + 2]),
+                                   vmull_s16(c1, O[j + 3]));
+        int16x4_t res1 = vrshrn_n_s32(vpaddq_s32(t01, t23), shift);
+        vst1_s16(d + 1 * line, res1);
+
+        t01 = vpaddq_s32(vmull_s16(c3, O[j + 0]), vmull_s16(c3, O[j + 1]));
+        t23 = vpaddq_s32(vmull_s16(c3, O[j + 2]), vmull_s16(c3, O[j + 3]));
+        int16x4_t res3 = vrshrn_n_s32(vpaddq_s32(t01, t23), shift);
+        vst1_s16(d + 3 * line, res3);
+
+        t01 = vpaddq_s32(vmull_s16(c5, O[j + 0]), vmull_s16(c5, O[j + 1]));
+        t23 = vpaddq_s32(vmull_s16(c5, O[j + 2]), vmull_s16(c5, O[j + 3]));
+        int16x4_t res5 = vrshrn_n_s32(vpaddq_s32(t01, t23), shift);
+        vst1_s16(d + 5 * line, res5);
+
+        t01 = vpaddq_s32(vmull_s16(c7, O[j + 0]), vmull_s16(c7, O[j + 1]));
+        t23 = vpaddq_s32(vmull_s16(c7, O[j + 2]), vmull_s16(c7, O[j + 3]));
+        int16x4_t res7 = vrshrn_n_s32(vpaddq_s32(t01, t23), shift);
+        vst1_s16(d + 7 * line, res7);
+
+        // EE and EO
+        int32x4_t t0 = vpaddq_s32(EE[j / 2 + 0], EE[j / 2 + 1]);
+        int32x4_t t1 = vmulq_s32(c0, t0);
+        int16x4_t res0 = vrshrn_n_s32(t1, shift);
+        vst1_s16(d + 0 * line, res0);
+
+        int32x4_t t2 = vmulq_s32(c2, EO[j / 2 + 0]);
+        int32x4_t t3 = vmulq_s32(c2, EO[j / 2 + 1]);
+        int16x4_t res2 = vrshrn_n_s32(vpaddq_s32(t2, t3), shift);
+        vst1_s16(d + 2 * line, res2);
+
+        int32x4_t t4 = vmulq_s32(c4, EE[j / 2 + 0]);
+        int32x4_t t5 = vmulq_s32(c4, EE[j / 2 + 1]);
+        int16x4_t res4 = vrshrn_n_s32(vpaddq_s32(t4, t5), shift);
+        vst1_s16(d + 4 * line, res4);
+
+        int32x4_t t6 = vmulq_s32(c6, EO[j / 2 + 0]);
+        int32x4_t t7 = vmulq_s32(c6, EO[j / 2 + 1]);
+        int16x4_t res6 = vrshrn_n_s32(vpaddq_s32(t6, t7), shift);
+        vst1_s16(d + 6 * line, res6);
+
+        d += 4;
     }
 }
 
@@ -819,8 +881,8 @@ namespace X265_NS
 // x265 private namespace
 void dct8_neon(const int16_t *src, int16_t *dst, intptr_t srcStride)
 {
-    const int shift_1st = 2 + X265_DEPTH - 8;
-    const int shift_2nd = 9;
+    const int shift_pass1 = 2 + X265_DEPTH - 8;
+    const int shift_pass2 = 9;
 
     ALIGN_VAR_32(int16_t, coef[8 * 8]);
     ALIGN_VAR_32(int16_t, block[8 * 8]);
@@ -830,8 +892,8 @@ void dct8_neon(const int16_t *src, int16_t *dst, intptr_t srcStride)
         memcpy(&block[i * 8], &src[i * srcStride], 8 * sizeof(int16_t));
     }
 
-    partialButterfly8(block, coef, shift_1st, 8);
-    partialButterfly8(coef, dst, shift_2nd, 8);
+    partialButterfly8_neon<shift_pass1>(block, coef);
+    partialButterfly8_neon<shift_pass2>(coef, dst);
 }
 
 void dct16_neon(const int16_t *src, int16_t *dst, intptr_t srcStride)
-- 
2.42.1
