[x265] [PATCH 2/2] AArch64: Optimise interp8_horiz_ps_i8mm when coeff == 2

Gerda Zsejke More gerdazsejke.more at arm.com
Thu Apr 24 10:18:01 UTC 2025


To avoid computing the same value twice in the USMMLA instruction, we
need to use a staggered filter with 7 taps or fewer. We can do this
easily for filter indices 1 and 3, as one of their taps is 0.

In order to take advantage of the matrix multiply instruction for
filter index (coeffIdx) 2, we can adjust the implementation by
subtracting the source elements corresponding to the leading filter
tap of value -1, and use the USMMLA instruction for the remaining 7
filter taps.
---
 source/common/aarch64/filter-neon-i8mm.cpp | 295 +++++----------------
 1 file changed, 59 insertions(+), 236 deletions(-)

diff --git a/source/common/aarch64/filter-neon-i8mm.cpp b/source/common/aarch64/filter-neon-i8mm.cpp
index d94660764..93544c5d4 100644
--- a/source/common/aarch64/filter-neon-i8mm.cpp
+++ b/source/common/aarch64/filter-neon-i8mm.cpp
@@ -60,17 +60,6 @@ static const uint8_t dot_prod_merge_block_tbl[48] = {
     3, 16, 17, 18, 7, 20, 21, 22, 11, 24, 25, 26, 15, 28, 29, 30
 };
 
-void inline init_sample_permute(uint8x8_t *samples, const uint8x16x3_t tbl,
-                                uint8x16_t *d)
-{
-    // Permute input samples for dot product.
-    // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
-    d[0] = vqtbl1q_u8(vcombine_u8(samples[0], vdup_n_u8(0)), tbl.val[0]);
-    d[1] = vqtbl1q_u8(vcombine_u8(samples[1], vdup_n_u8(0)), tbl.val[0]);
-    d[2] = vqtbl1q_u8(vcombine_u8(samples[2], vdup_n_u8(0)), tbl.val[0]);
-    d[3] = vqtbl1q_u8(vcombine_u8(samples[3], vdup_n_u8(0)), tbl.val[0]);
-}
-
 template<bool coeff2>
 uint8x8_t inline filter8_8_pp_matmul(uint8x16_t samples, const int8x16_t filter,
                                      const uint8x16x2_t tbl)
@@ -96,70 +85,7 @@ uint8x8_t inline filter8_8_pp_matmul(uint8x16_t samples, const int8x16_t filter,
     return vqrshrun_n_s16(matmul, IF_FILTER_PREC);
 }
 
-int16x4_t inline filter8_4_ps(uint8x16_t samples, const int8x8_t filter,
-                              const int16x8_t constant, const uint8x16x3_t tbl)
-{
-    // Permute input samples for dot product.
-    // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
-    uint8x16_t perm_s0 = vqtbl1q_u8(samples, tbl.val[0]);
-    // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
-    uint8x16_t perm_s1 = vqtbl1q_u8(samples, tbl.val[1]);
-
-    int32x4_t dotprod = vusdotq_lane_s32(vdupq_n_s32(0), perm_s0, filter, 0);
-    dotprod = vusdotq_lane_s32(dotprod, perm_s1, filter, 1);
-
-    // Narrow.
-    return vadd_s16(vmovn_s32(dotprod), vget_low_s16(constant));
-}
-
-int16x8_t inline filter8_8_ps(uint8x16_t samples, const int8x8_t filter,
-                              const int16x8_t constant, const uint8x16x3_t tbl)
-{
-    // Permute input samples for dot product.
-    // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
-    uint8x16_t perm_s0 = vqtbl1q_u8(samples, tbl.val[0]);
-    // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
-    uint8x16_t perm_s1 = vqtbl1q_u8(samples, tbl.val[1]);
-    // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
-    uint8x16_t perm_S2 = vqtbl1q_u8(samples, tbl.val[2]);
-
-    int32x4_t dotprod_lo = vusdotq_lane_s32(vdupq_n_s32(0), perm_s0, filter, 0);
-    dotprod_lo = vusdotq_lane_s32(dotprod_lo, perm_s1, filter, 1);
-    int32x4_t dotprod_hi = vusdotq_lane_s32(vdupq_n_s32(0), perm_s1, filter, 0);
-    dotprod_hi = vusdotq_lane_s32(dotprod_hi, perm_S2, filter, 1);
-
-    // Narrow and combine.
-    int16x8_t dotprod = vcombine_s16(vmovn_s32(dotprod_lo),
-                                     vmovn_s32(dotprod_hi));
-    return vaddq_s16(dotprod, constant);
-}
-
-int16x8_t inline filter8_8_ps_reuse(uint8x16_t samples, const int8x8_t filter,
-                                    const int16x8_t constant,
-                                    const uint8x16x3_t tbl, uint8x16_t &perm_s0)
-{
-    // Permute input samples for dot product.
-    // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
-    // Already in perm_s0.
-    // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
-    uint8x16_t perm_s1 = vqtbl1q_u8(samples, tbl.val[1]);
-    // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
-    uint8x16_t perm_s2 = vqtbl1q_u8(samples, tbl.val[2]);
-
-    int32x4_t dotprod_lo = vusdotq_lane_s32(vdupq_n_s32(0), perm_s0, filter, 0);
-    dotprod_lo = vusdotq_lane_s32(dotprod_lo, perm_s1, filter, 1);
-    int32x4_t dotprod_hi = vusdotq_lane_s32(vdupq_n_s32(0), perm_s1, filter, 0);
-    dotprod_hi = vusdotq_lane_s32(dotprod_hi, perm_s2, filter, 1);
-
-    // Save for re-use in next iteration.
-    perm_s0 = perm_s2;
-
-    // Narrow and combine.
-    int16x8_t dotprod = vcombine_s16(vmovn_s32(dotprod_lo),
-                                     vmovn_s32(dotprod_hi));
-    return vaddq_s16(dotprod, constant);
-}
-
+template<bool coeff2>
 int16x8_t inline filter8_8_ps_matmul(uint8x16_t samples, const int8x16_t filter,
                                      const int16x8_t constant,
                                      const uint8x16x2_t tbl)
@@ -173,9 +99,21 @@ int16x8_t inline filter8_8_ps_matmul(uint8x16_t samples, const int8x16_t filter,
 
     // Narrow and combine.
     int16x8_t matmul = vcombine_s16(vmovn_s32(matmul_lo), vmovn_s32(matmul_hi));
-    return vaddq_s16(matmul, constant);
+
+    int16x8_t offset_matmul = constant;
+
+    if (coeff2)
+    {
+        // Substract the source elements corresponding to filter tap value -1,
+        // which weren't included in the initial matrix multiplication.
+        offset_matmul = vreinterpretq_s16_u16(
+            vsubw_u8(vreinterpretq_u16_s16(offset_matmul), vget_low_u8(samples)));
+    }
+
+    return vaddq_s16(matmul, offset_matmul);
 }
 
+template<bool coeff2>
 int16x4_t inline filter8_4_ps_matmul(uint8x16_t samples, const int8x16_t filter,
                                      const int16x8_t constant,
                                      const uint8x16x2_t tbl)
@@ -185,7 +123,17 @@ int16x4_t inline filter8_4_ps_matmul(uint8x16_t samples, const int8x16_t filter,
 
     int32x4_t matmul = vusmmlaq_s32(vdupq_n_s32(0), perm, filter);
 
-    return vadd_s16(vmovn_s32(matmul), vget_low_s16(constant));
+    int16x8_t offset_matmul = constant;
+
+    if (coeff2)
+    {
+        // Substract the source elements corresponding to filter tap value -1,
+        // which weren't included in the initial matrix multiplication.
+        offset_matmul = vreinterpretq_s16_u16(
+            vsubw_u8(vreinterpretq_u16_s16(offset_matmul), vget_low_u8(samples)));
+    }
+
+    return vadd_s16(vmovn_s32(matmul), vget_low_s16(offset_matmul));
 }
 
 uint8x8_t inline filter4_8_pp(uint8x16_t samples, const int8x8_t filter,
@@ -385,134 +333,16 @@ void interp8_horiz_pp_i8mm(const uint8_t *src, intptr_t srcStride, uint8_t *dst,
     }
 }
 
-template<int width, int height>
-void inline interp8_horiz_ps_dotprod(const uint8_t *src, intptr_t srcStride,
-                                     int16_t *dst, intptr_t dstStride,
-                                     int coeffIdx, int isRowExt)
-{
-    const int offset = (unsigned)-IF_INTERNAL_OFFS;
-
-    const int N_TAPS = 8;
-    int blkheight = height;
-
-    src -= N_TAPS / 2 - 1;
-    if (isRowExt)
-    {
-        src -= (N_TAPS / 2 - 1) * srcStride;
-        blkheight += N_TAPS - 1;
-    }
-
-    const uint8x16x3_t tbl = vld1q_u8_x3(dotprod_permute_tbl);
-    const int8x8_t filter = vmovn_s16(vld1q_s16(g_lumaFilter[coeffIdx]));
-    const int16x8_t c = vdupq_n_s16(offset);
-
-    for (int row = 0; row + 4 <= blkheight; row += 4)
-    {
-        int col = 0;
-        if (width >= 32)
-        {
-            // Peel first sample permute to enable passing between iterations.
-            uint8x8_t s0[4];
-            load_u8x8xn<4>(src, srcStride, s0);
-            uint8x16_t ps0[4];
-            init_sample_permute(s0, tbl, ps0);
-
-            for (; col + 16 <= width; col += 16)
-            {
-                uint8x16_t s_lo[4], s_hi[4];
-                load_u8x16xn<4>(src + col + 0, srcStride, s_lo);
-                load_u8x16xn<4>(src + col + 8, srcStride, s_hi);
-
-                int16x8_t d_lo[4];
-                d_lo[0] = filter8_8_ps_reuse(s_lo[0], filter, c, tbl, ps0[0]);
-                d_lo[1] = filter8_8_ps_reuse(s_lo[1], filter, c, tbl, ps0[1]);
-                d_lo[2] = filter8_8_ps_reuse(s_lo[2], filter, c, tbl, ps0[2]);
-                d_lo[3] = filter8_8_ps_reuse(s_lo[3], filter, c, tbl, ps0[3]);
-
-                int16x8_t d_hi[4];
-                d_hi[0] = filter8_8_ps_reuse(s_hi[0], filter, c, tbl, ps0[0]);
-                d_hi[1] = filter8_8_ps_reuse(s_hi[1], filter, c, tbl, ps0[1]);
-                d_hi[2] = filter8_8_ps_reuse(s_hi[2], filter, c, tbl, ps0[2]);
-                d_hi[3] = filter8_8_ps_reuse(s_hi[3], filter, c, tbl, ps0[3]);
-
-                store_s16x8xn<4>(dst + col + 0, dstStride, d_lo);
-                store_s16x8xn<4>(dst + col + 8, dstStride, d_hi);
-            }
-        }
-        else
-        {
-            for (; col + 8 <= width; col += 8)
-            {
-                uint8x16_t s[4];
-                load_u8x16xn<4>(src + col, srcStride, s);
-
-                int16x8_t d[4];
-                d[0] = filter8_8_ps(s[0], filter, c, tbl);
-                d[1] = filter8_8_ps(s[1], filter, c, tbl);
-                d[2] = filter8_8_ps(s[2], filter, c, tbl);
-                d[3] = filter8_8_ps(s[3], filter, c, tbl);
-
-                store_s16x8xn<4>(dst + col, dstStride, d);
-            }
-        }
-        for (; col < width; col += 4)
-        {
-            uint8x16_t s[4];
-            load_u8x16xn<4>(src + col, srcStride, s);
-
-            int16x4_t d[4];
-            d[0] = filter8_4_ps(s[0], filter, c, tbl);
-            d[1] = filter8_4_ps(s[1], filter, c, tbl);
-            d[2] = filter8_4_ps(s[2], filter, c, tbl);
-            d[3] = filter8_4_ps(s[3], filter, c, tbl);
-
-            store_s16x4xn<4>(dst + col, dstStride, d);
-        }
-
-        src += 4 * srcStride;
-        dst += 4 * dstStride;
-    }
-
-    if (isRowExt)
-    {
-        // process final 3 rows
-        int col = 0;
-        for (; (col + 8) <= width; col += 8)
-        {
-            uint8x16_t s[3];
-            load_u8x16xn<3>(src + col, srcStride, s);
-
-            int16x8_t d[3];
-            d[0] = filter8_8_ps(s[0], filter, c, tbl);
-            d[1] = filter8_8_ps(s[1], filter, c, tbl);
-            d[2] = filter8_8_ps(s[2], filter, c, tbl);
-
-            store_s16x8xn<3>(dst + col, dstStride, d);
-        }
-
-        for (; col < width; col += 4)
-        {
-            uint8x16_t s[3];
-            load_u8x16xn<3>(src + col, srcStride, s);
-
-            int16x4_t d[3];
-            d[0] = filter8_4_ps(s[0], filter, c, tbl);
-            d[1] = filter8_4_ps(s[1], filter, c, tbl);
-            d[2] = filter8_4_ps(s[2], filter, c, tbl);
-
-            store_s16x4xn<3>(dst + col, dstStride, d);
-        }
-    }
-}
-
-template<int coeffIdx, int width, int height>
+template<bool coeff2, int width, int height>
 void inline interp8_horiz_ps_matmul(const uint8_t *src, intptr_t srcStride,
                                     int16_t *dst, intptr_t dstStride,
-                                    int isRowExt)
+                                    int coeffIdx, int isRowExt)
 {
     const int offset = (unsigned)-IF_INTERNAL_OFFS;
-
     const int N_TAPS = 8;
+    const uint8x16x2_t tbl = vld1q_u8_x2(matmul_permute_tbl[coeffIdx >> 1]);
+    const int8x16_t filter = vld1q_s8(matmul_luma_filter[coeffIdx - 1]);
+    const int16x8_t c = vdupq_n_s16(offset);
     int blkheight = height;
 
     src -= N_TAPS / 2 - 1;
@@ -522,11 +352,6 @@ void inline interp8_horiz_ps_matmul(const uint8_t *src, intptr_t srcStride,
         blkheight += N_TAPS - 1;
     }
 
-    const uint8x16x2_t tbl = vld1q_u8_x2(matmul_permute_tbl[coeffIdx >> 1]);
-    const int8x16_t filter = vld1q_s8(matmul_luma_filter[coeffIdx - 1]);
-
-    const int16x8_t c = vdupq_n_s16(offset);
-
     for (int row = 0; row + 4 <= blkheight; row += 4)
     {
         int col = 0;
@@ -539,16 +364,16 @@ void inline interp8_horiz_ps_matmul(const uint8_t *src, intptr_t srcStride,
                 load_u8x16xn<4>(src + col + 8, srcStride, s_hi);
 
                 int16x8_t d_lo[4];
-                d_lo[0] = filter8_8_ps_matmul(s_lo[0], filter, c, tbl);
-                d_lo[1] = filter8_8_ps_matmul(s_lo[1], filter, c, tbl);
-                d_lo[2] = filter8_8_ps_matmul(s_lo[2], filter, c, tbl);
-                d_lo[3] = filter8_8_ps_matmul(s_lo[3], filter, c, tbl);
+                d_lo[0] = filter8_8_ps_matmul<coeff2>(s_lo[0], filter, c, tbl);
+                d_lo[1] = filter8_8_ps_matmul<coeff2>(s_lo[1], filter, c, tbl);
+                d_lo[2] = filter8_8_ps_matmul<coeff2>(s_lo[2], filter, c, tbl);
+                d_lo[3] = filter8_8_ps_matmul<coeff2>(s_lo[3], filter, c, tbl);
 
                 int16x8_t d_hi[4];
-                d_hi[0] = filter8_8_ps_matmul(s_hi[0], filter, c, tbl);
-                d_hi[1] = filter8_8_ps_matmul(s_hi[1], filter, c, tbl);
-                d_hi[2] = filter8_8_ps_matmul(s_hi[2], filter, c, tbl);
-                d_hi[3] = filter8_8_ps_matmul(s_hi[3], filter, c, tbl);
+                d_hi[0] = filter8_8_ps_matmul<coeff2>(s_hi[0], filter, c, tbl);
+                d_hi[1] = filter8_8_ps_matmul<coeff2>(s_hi[1], filter, c, tbl);
+                d_hi[2] = filter8_8_ps_matmul<coeff2>(s_hi[2], filter, c, tbl);
+                d_hi[3] = filter8_8_ps_matmul<coeff2>(s_hi[3], filter, c, tbl);
 
                 store_s16x8xn<4>(dst + col + 0, dstStride, d_lo);
                 store_s16x8xn<4>(dst + col + 8, dstStride, d_hi);
@@ -562,10 +387,10 @@ void inline interp8_horiz_ps_matmul(const uint8_t *src, intptr_t srcStride,
                 load_u8x16xn<4>(src + col, srcStride, s);
 
                 int16x8_t d[4];
-                d[0] = filter8_8_ps_matmul(s[0], filter, c, tbl);
-                d[1] = filter8_8_ps_matmul(s[1], filter, c, tbl);
-                d[2] = filter8_8_ps_matmul(s[2], filter, c, tbl);
-                d[3] = filter8_8_ps_matmul(s[3], filter, c, tbl);
+                d[0] = filter8_8_ps_matmul<coeff2>(s[0], filter, c, tbl);
+                d[1] = filter8_8_ps_matmul<coeff2>(s[1], filter, c, tbl);
+                d[2] = filter8_8_ps_matmul<coeff2>(s[2], filter, c, tbl);
+                d[3] = filter8_8_ps_matmul<coeff2>(s[3], filter, c, tbl);
 
                 store_s16x8xn<4>(dst + col, dstStride, d);
             }
@@ -576,10 +401,10 @@ void inline interp8_horiz_ps_matmul(const uint8_t *src, intptr_t srcStride,
             load_u8x16xn<4>(src + col, srcStride, s);
 
             int16x4_t d[4];
-            d[0] = filter8_4_ps_matmul(s[0], filter, c, tbl);
-            d[1] = filter8_4_ps_matmul(s[1], filter, c, tbl);
-            d[2] = filter8_4_ps_matmul(s[2], filter, c, tbl);
-            d[3] = filter8_4_ps_matmul(s[3], filter, c, tbl);
+            d[0] = filter8_4_ps_matmul<coeff2>(s[0], filter, c, tbl);
+            d[1] = filter8_4_ps_matmul<coeff2>(s[1], filter, c, tbl);
+            d[2] = filter8_4_ps_matmul<coeff2>(s[2], filter, c, tbl);
+            d[3] = filter8_4_ps_matmul<coeff2>(s[3], filter, c, tbl);
 
             store_s16x4xn<4>(dst + col, dstStride, d);
         }
@@ -598,9 +423,9 @@ void inline interp8_horiz_ps_matmul(const uint8_t *src, intptr_t srcStride,
             load_u8x16xn<3>(src + col, srcStride, s);
 
             int16x8_t d[3];
-            d[0] = filter8_8_ps_matmul(s[0], filter, c, tbl);
-            d[1] = filter8_8_ps_matmul(s[1], filter, c, tbl);
-            d[2] = filter8_8_ps_matmul(s[2], filter, c, tbl);
+            d[0] = filter8_8_ps_matmul<coeff2>(s[0], filter, c, tbl);
+            d[1] = filter8_8_ps_matmul<coeff2>(s[1], filter, c, tbl);
+            d[2] = filter8_8_ps_matmul<coeff2>(s[2], filter, c, tbl);
 
             store_s16x8xn<3>(dst + col, dstStride, d);
         }
@@ -611,9 +436,9 @@ void inline interp8_horiz_ps_matmul(const uint8_t *src, intptr_t srcStride,
             load_u8x16xn<3>(src + col, srcStride, s);
 
             int16x4_t d[3];
-            d[0] = filter8_4_ps_matmul(s[0], filter, c, tbl);
-            d[1] = filter8_4_ps_matmul(s[1], filter, c, tbl);
-            d[2] = filter8_4_ps_matmul(s[2], filter, c, tbl);
+            d[0] = filter8_4_ps_matmul<coeff2>(s[0], filter, c, tbl);
+            d[1] = filter8_4_ps_matmul<coeff2>(s[1], filter, c, tbl);
+            d[2] = filter8_4_ps_matmul<coeff2>(s[2], filter, c, tbl);
 
             store_s16x4xn<3>(dst + col, dstStride, d);
         }
@@ -626,16 +451,14 @@ void interp8_horiz_ps_i8mm(const uint8_t *src, intptr_t srcStride, int16_t *dst,
 {
     switch (coeffIdx)
     {
-    case 1:
-        return interp8_horiz_ps_matmul<1, width, height>(src, srcStride, dst,
-                                                         dstStride, isRowExt);
     case 2:
-        return interp8_horiz_ps_dotprod<width, height>(src, srcStride, dst,
-                                                       dstStride, coeffIdx,
-                                                       isRowExt);
-    case 3:
-        return interp8_horiz_ps_matmul<3, width, height>(src, srcStride, dst,
-                                                         dstStride, isRowExt);
+        return interp8_horiz_ps_matmul<true, width, height>(src, srcStride, dst,
+                                                            dstStride, coeffIdx,
+                                                            isRowExt);
+    default:
+        return interp8_horiz_ps_matmul<false, width, height>(src, srcStride, dst,
+                                                             dstStride, coeffIdx,
+                                                             isRowExt);
     }
 }
 
-- 
2.39.5 (Apple Git-154)

-------------- next part --------------
>From b1925efa5db502a6d70c13538fd2fb05a2508ec4 Mon Sep 17 00:00:00 2001
Message-Id: <b1925efa5db502a6d70c13538fd2fb05a2508ec4.1745489546.git.gerdazsejke.more at arm.com>
In-Reply-To: <cover.1745489546.git.gerdazsejke.more at arm.com>
References: <cover.1745489546.git.gerdazsejke.more at arm.com>
From: Gerda Zsejke More <gerdazsejke.more at arm.com>
Date: Wed, 16 Apr 2025 16:06:15 +0200
Subject: [PATCH 2/2] AArch64: Optimise interp8_horiz_ps_i8mm when coeff == 2

To avoid computing the same value twice in the USMMLA instruction, we
need to use a staggered filter with 7 taps or fewer. We can do this
easily for filter indices 1 and 3, as one of their taps is 0.

In order to take advantage of the matrix multiply instruction for
filter index (coeffIdx) 2, we can adjust the implementation by
subtracting the source elements corresponding to the leading filter
tap of value -1, and use the USMMLA instruction for the remaining 7
filter taps.
---
 source/common/aarch64/filter-neon-i8mm.cpp | 295 +++++----------------
 1 file changed, 59 insertions(+), 236 deletions(-)

diff --git a/source/common/aarch64/filter-neon-i8mm.cpp b/source/common/aarch64/filter-neon-i8mm.cpp
index d94660764..93544c5d4 100644
--- a/source/common/aarch64/filter-neon-i8mm.cpp
+++ b/source/common/aarch64/filter-neon-i8mm.cpp
@@ -60,17 +60,6 @@ static const uint8_t dot_prod_merge_block_tbl[48] = {
     3, 16, 17, 18, 7, 20, 21, 22, 11, 24, 25, 26, 15, 28, 29, 30
 };
 
-void inline init_sample_permute(uint8x8_t *samples, const uint8x16x3_t tbl,
-                                uint8x16_t *d)
-{
-    // Permute input samples for dot product.
-    // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
-    d[0] = vqtbl1q_u8(vcombine_u8(samples[0], vdup_n_u8(0)), tbl.val[0]);
-    d[1] = vqtbl1q_u8(vcombine_u8(samples[1], vdup_n_u8(0)), tbl.val[0]);
-    d[2] = vqtbl1q_u8(vcombine_u8(samples[2], vdup_n_u8(0)), tbl.val[0]);
-    d[3] = vqtbl1q_u8(vcombine_u8(samples[3], vdup_n_u8(0)), tbl.val[0]);
-}
-
 template<bool coeff2>
 uint8x8_t inline filter8_8_pp_matmul(uint8x16_t samples, const int8x16_t filter,
                                      const uint8x16x2_t tbl)
@@ -96,70 +85,7 @@ uint8x8_t inline filter8_8_pp_matmul(uint8x16_t samples, const int8x16_t filter,
     return vqrshrun_n_s16(matmul, IF_FILTER_PREC);
 }
 
-int16x4_t inline filter8_4_ps(uint8x16_t samples, const int8x8_t filter,
-                              const int16x8_t constant, const uint8x16x3_t tbl)
-{
-    // Permute input samples for dot product.
-    // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
-    uint8x16_t perm_s0 = vqtbl1q_u8(samples, tbl.val[0]);
-    // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
-    uint8x16_t perm_s1 = vqtbl1q_u8(samples, tbl.val[1]);
-
-    int32x4_t dotprod = vusdotq_lane_s32(vdupq_n_s32(0), perm_s0, filter, 0);
-    dotprod = vusdotq_lane_s32(dotprod, perm_s1, filter, 1);
-
-    // Narrow.
-    return vadd_s16(vmovn_s32(dotprod), vget_low_s16(constant));
-}
-
-int16x8_t inline filter8_8_ps(uint8x16_t samples, const int8x8_t filter,
-                              const int16x8_t constant, const uint8x16x3_t tbl)
-{
-    // Permute input samples for dot product.
-    // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
-    uint8x16_t perm_s0 = vqtbl1q_u8(samples, tbl.val[0]);
-    // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
-    uint8x16_t perm_s1 = vqtbl1q_u8(samples, tbl.val[1]);
-    // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
-    uint8x16_t perm_S2 = vqtbl1q_u8(samples, tbl.val[2]);
-
-    int32x4_t dotprod_lo = vusdotq_lane_s32(vdupq_n_s32(0), perm_s0, filter, 0);
-    dotprod_lo = vusdotq_lane_s32(dotprod_lo, perm_s1, filter, 1);
-    int32x4_t dotprod_hi = vusdotq_lane_s32(vdupq_n_s32(0), perm_s1, filter, 0);
-    dotprod_hi = vusdotq_lane_s32(dotprod_hi, perm_S2, filter, 1);
-
-    // Narrow and combine.
-    int16x8_t dotprod = vcombine_s16(vmovn_s32(dotprod_lo),
-                                     vmovn_s32(dotprod_hi));
-    return vaddq_s16(dotprod, constant);
-}
-
-int16x8_t inline filter8_8_ps_reuse(uint8x16_t samples, const int8x8_t filter,
-                                    const int16x8_t constant,
-                                    const uint8x16x3_t tbl, uint8x16_t &perm_s0)
-{
-    // Permute input samples for dot product.
-    // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
-    // Already in perm_s0.
-    // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
-    uint8x16_t perm_s1 = vqtbl1q_u8(samples, tbl.val[1]);
-    // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
-    uint8x16_t perm_s2 = vqtbl1q_u8(samples, tbl.val[2]);
-
-    int32x4_t dotprod_lo = vusdotq_lane_s32(vdupq_n_s32(0), perm_s0, filter, 0);
-    dotprod_lo = vusdotq_lane_s32(dotprod_lo, perm_s1, filter, 1);
-    int32x4_t dotprod_hi = vusdotq_lane_s32(vdupq_n_s32(0), perm_s1, filter, 0);
-    dotprod_hi = vusdotq_lane_s32(dotprod_hi, perm_s2, filter, 1);
-
-    // Save for re-use in next iteration.
-    perm_s0 = perm_s2;
-
-    // Narrow and combine.
-    int16x8_t dotprod = vcombine_s16(vmovn_s32(dotprod_lo),
-                                     vmovn_s32(dotprod_hi));
-    return vaddq_s16(dotprod, constant);
-}
-
+template<bool coeff2>
 int16x8_t inline filter8_8_ps_matmul(uint8x16_t samples, const int8x16_t filter,
                                      const int16x8_t constant,
                                      const uint8x16x2_t tbl)
@@ -173,9 +99,21 @@ int16x8_t inline filter8_8_ps_matmul(uint8x16_t samples, const int8x16_t filter,
 
     // Narrow and combine.
     int16x8_t matmul = vcombine_s16(vmovn_s32(matmul_lo), vmovn_s32(matmul_hi));
-    return vaddq_s16(matmul, constant);
+
+    int16x8_t offset_matmul = constant;
+
+    if (coeff2)
+    {
+        // Substract the source elements corresponding to filter tap value -1,
+        // which weren't included in the initial matrix multiplication.
+        offset_matmul = vreinterpretq_s16_u16(
+            vsubw_u8(vreinterpretq_u16_s16(offset_matmul), vget_low_u8(samples)));
+    }
+
+    return vaddq_s16(matmul, offset_matmul);
 }
 
+template<bool coeff2>
 int16x4_t inline filter8_4_ps_matmul(uint8x16_t samples, const int8x16_t filter,
                                      const int16x8_t constant,
                                      const uint8x16x2_t tbl)
@@ -185,7 +123,17 @@ int16x4_t inline filter8_4_ps_matmul(uint8x16_t samples, const int8x16_t filter,
 
     int32x4_t matmul = vusmmlaq_s32(vdupq_n_s32(0), perm, filter);
 
-    return vadd_s16(vmovn_s32(matmul), vget_low_s16(constant));
+    int16x8_t offset_matmul = constant;
+
+    if (coeff2)
+    {
+        // Substract the source elements corresponding to filter tap value -1,
+        // which weren't included in the initial matrix multiplication.
+        offset_matmul = vreinterpretq_s16_u16(
+            vsubw_u8(vreinterpretq_u16_s16(offset_matmul), vget_low_u8(samples)));
+    }
+
+    return vadd_s16(vmovn_s32(matmul), vget_low_s16(offset_matmul));
 }
 
 uint8x8_t inline filter4_8_pp(uint8x16_t samples, const int8x8_t filter,
@@ -385,134 +333,16 @@ void interp8_horiz_pp_i8mm(const uint8_t *src, intptr_t srcStride, uint8_t *dst,
     }
 }
 
-template<int width, int height>
-void inline interp8_horiz_ps_dotprod(const uint8_t *src, intptr_t srcStride,
-                                     int16_t *dst, intptr_t dstStride,
-                                     int coeffIdx, int isRowExt)
-{
-    const int offset = (unsigned)-IF_INTERNAL_OFFS;
-
-    const int N_TAPS = 8;
-    int blkheight = height;
-
-    src -= N_TAPS / 2 - 1;
-    if (isRowExt)
-    {
-        src -= (N_TAPS / 2 - 1) * srcStride;
-        blkheight += N_TAPS - 1;
-    }
-
-    const uint8x16x3_t tbl = vld1q_u8_x3(dotprod_permute_tbl);
-    const int8x8_t filter = vmovn_s16(vld1q_s16(g_lumaFilter[coeffIdx]));
-    const int16x8_t c = vdupq_n_s16(offset);
-
-    for (int row = 0; row + 4 <= blkheight; row += 4)
-    {
-        int col = 0;
-        if (width >= 32)
-        {
-            // Peel first sample permute to enable passing between iterations.
-            uint8x8_t s0[4];
-            load_u8x8xn<4>(src, srcStride, s0);
-            uint8x16_t ps0[4];
-            init_sample_permute(s0, tbl, ps0);
-
-            for (; col + 16 <= width; col += 16)
-            {
-                uint8x16_t s_lo[4], s_hi[4];
-                load_u8x16xn<4>(src + col + 0, srcStride, s_lo);
-                load_u8x16xn<4>(src + col + 8, srcStride, s_hi);
-
-                int16x8_t d_lo[4];
-                d_lo[0] = filter8_8_ps_reuse(s_lo[0], filter, c, tbl, ps0[0]);
-                d_lo[1] = filter8_8_ps_reuse(s_lo[1], filter, c, tbl, ps0[1]);
-                d_lo[2] = filter8_8_ps_reuse(s_lo[2], filter, c, tbl, ps0[2]);
-                d_lo[3] = filter8_8_ps_reuse(s_lo[3], filter, c, tbl, ps0[3]);
-
-                int16x8_t d_hi[4];
-                d_hi[0] = filter8_8_ps_reuse(s_hi[0], filter, c, tbl, ps0[0]);
-                d_hi[1] = filter8_8_ps_reuse(s_hi[1], filter, c, tbl, ps0[1]);
-                d_hi[2] = filter8_8_ps_reuse(s_hi[2], filter, c, tbl, ps0[2]);
-                d_hi[3] = filter8_8_ps_reuse(s_hi[3], filter, c, tbl, ps0[3]);
-
-                store_s16x8xn<4>(dst + col + 0, dstStride, d_lo);
-                store_s16x8xn<4>(dst + col + 8, dstStride, d_hi);
-            }
-        }
-        else
-        {
-            for (; col + 8 <= width; col += 8)
-            {
-                uint8x16_t s[4];
-                load_u8x16xn<4>(src + col, srcStride, s);
-
-                int16x8_t d[4];
-                d[0] = filter8_8_ps(s[0], filter, c, tbl);
-                d[1] = filter8_8_ps(s[1], filter, c, tbl);
-                d[2] = filter8_8_ps(s[2], filter, c, tbl);
-                d[3] = filter8_8_ps(s[3], filter, c, tbl);
-
-                store_s16x8xn<4>(dst + col, dstStride, d);
-            }
-        }
-        for (; col < width; col += 4)
-        {
-            uint8x16_t s[4];
-            load_u8x16xn<4>(src + col, srcStride, s);
-
-            int16x4_t d[4];
-            d[0] = filter8_4_ps(s[0], filter, c, tbl);
-            d[1] = filter8_4_ps(s[1], filter, c, tbl);
-            d[2] = filter8_4_ps(s[2], filter, c, tbl);
-            d[3] = filter8_4_ps(s[3], filter, c, tbl);
-
-            store_s16x4xn<4>(dst + col, dstStride, d);
-        }
-
-        src += 4 * srcStride;
-        dst += 4 * dstStride;
-    }
-
-    if (isRowExt)
-    {
-        // process final 3 rows
-        int col = 0;
-        for (; (col + 8) <= width; col += 8)
-        {
-            uint8x16_t s[3];
-            load_u8x16xn<3>(src + col, srcStride, s);
-
-            int16x8_t d[3];
-            d[0] = filter8_8_ps(s[0], filter, c, tbl);
-            d[1] = filter8_8_ps(s[1], filter, c, tbl);
-            d[2] = filter8_8_ps(s[2], filter, c, tbl);
-
-            store_s16x8xn<3>(dst + col, dstStride, d);
-        }
-
-        for (; col < width; col += 4)
-        {
-            uint8x16_t s[3];
-            load_u8x16xn<3>(src + col, srcStride, s);
-
-            int16x4_t d[3];
-            d[0] = filter8_4_ps(s[0], filter, c, tbl);
-            d[1] = filter8_4_ps(s[1], filter, c, tbl);
-            d[2] = filter8_4_ps(s[2], filter, c, tbl);
-
-            store_s16x4xn<3>(dst + col, dstStride, d);
-        }
-    }
-}
-
-template<int coeffIdx, int width, int height>
+template<bool coeff2, int width, int height>
 void inline interp8_horiz_ps_matmul(const uint8_t *src, intptr_t srcStride,
                                     int16_t *dst, intptr_t dstStride,
-                                    int isRowExt)
+                                    int coeffIdx, int isRowExt)
 {
     const int offset = (unsigned)-IF_INTERNAL_OFFS;
-
     const int N_TAPS = 8;
+    const uint8x16x2_t tbl = vld1q_u8_x2(matmul_permute_tbl[coeffIdx >> 1]);
+    const int8x16_t filter = vld1q_s8(matmul_luma_filter[coeffIdx - 1]);
+    const int16x8_t c = vdupq_n_s16(offset);
     int blkheight = height;
 
     src -= N_TAPS / 2 - 1;
@@ -522,11 +352,6 @@ void inline interp8_horiz_ps_matmul(const uint8_t *src, intptr_t srcStride,
         blkheight += N_TAPS - 1;
     }
 
-    const uint8x16x2_t tbl = vld1q_u8_x2(matmul_permute_tbl[coeffIdx >> 1]);
-    const int8x16_t filter = vld1q_s8(matmul_luma_filter[coeffIdx - 1]);
-
-    const int16x8_t c = vdupq_n_s16(offset);
-
     for (int row = 0; row + 4 <= blkheight; row += 4)
     {
         int col = 0;
@@ -539,16 +364,16 @@ void inline interp8_horiz_ps_matmul(const uint8_t *src, intptr_t srcStride,
                 load_u8x16xn<4>(src + col + 8, srcStride, s_hi);
 
                 int16x8_t d_lo[4];
-                d_lo[0] = filter8_8_ps_matmul(s_lo[0], filter, c, tbl);
-                d_lo[1] = filter8_8_ps_matmul(s_lo[1], filter, c, tbl);
-                d_lo[2] = filter8_8_ps_matmul(s_lo[2], filter, c, tbl);
-                d_lo[3] = filter8_8_ps_matmul(s_lo[3], filter, c, tbl);
+                d_lo[0] = filter8_8_ps_matmul<coeff2>(s_lo[0], filter, c, tbl);
+                d_lo[1] = filter8_8_ps_matmul<coeff2>(s_lo[1], filter, c, tbl);
+                d_lo[2] = filter8_8_ps_matmul<coeff2>(s_lo[2], filter, c, tbl);
+                d_lo[3] = filter8_8_ps_matmul<coeff2>(s_lo[3], filter, c, tbl);
 
                 int16x8_t d_hi[4];
-                d_hi[0] = filter8_8_ps_matmul(s_hi[0], filter, c, tbl);
-                d_hi[1] = filter8_8_ps_matmul(s_hi[1], filter, c, tbl);
-                d_hi[2] = filter8_8_ps_matmul(s_hi[2], filter, c, tbl);
-                d_hi[3] = filter8_8_ps_matmul(s_hi[3], filter, c, tbl);
+                d_hi[0] = filter8_8_ps_matmul<coeff2>(s_hi[0], filter, c, tbl);
+                d_hi[1] = filter8_8_ps_matmul<coeff2>(s_hi[1], filter, c, tbl);
+                d_hi[2] = filter8_8_ps_matmul<coeff2>(s_hi[2], filter, c, tbl);
+                d_hi[3] = filter8_8_ps_matmul<coeff2>(s_hi[3], filter, c, tbl);
 
                 store_s16x8xn<4>(dst + col + 0, dstStride, d_lo);
                 store_s16x8xn<4>(dst + col + 8, dstStride, d_hi);
@@ -562,10 +387,10 @@ void inline interp8_horiz_ps_matmul(const uint8_t *src, intptr_t srcStride,
                 load_u8x16xn<4>(src + col, srcStride, s);
 
                 int16x8_t d[4];
-                d[0] = filter8_8_ps_matmul(s[0], filter, c, tbl);
-                d[1] = filter8_8_ps_matmul(s[1], filter, c, tbl);
-                d[2] = filter8_8_ps_matmul(s[2], filter, c, tbl);
-                d[3] = filter8_8_ps_matmul(s[3], filter, c, tbl);
+                d[0] = filter8_8_ps_matmul<coeff2>(s[0], filter, c, tbl);
+                d[1] = filter8_8_ps_matmul<coeff2>(s[1], filter, c, tbl);
+                d[2] = filter8_8_ps_matmul<coeff2>(s[2], filter, c, tbl);
+                d[3] = filter8_8_ps_matmul<coeff2>(s[3], filter, c, tbl);
 
                 store_s16x8xn<4>(dst + col, dstStride, d);
             }
@@ -576,10 +401,10 @@ void inline interp8_horiz_ps_matmul(const uint8_t *src, intptr_t srcStride,
             load_u8x16xn<4>(src + col, srcStride, s);
 
             int16x4_t d[4];
-            d[0] = filter8_4_ps_matmul(s[0], filter, c, tbl);
-            d[1] = filter8_4_ps_matmul(s[1], filter, c, tbl);
-            d[2] = filter8_4_ps_matmul(s[2], filter, c, tbl);
-            d[3] = filter8_4_ps_matmul(s[3], filter, c, tbl);
+            d[0] = filter8_4_ps_matmul<coeff2>(s[0], filter, c, tbl);
+            d[1] = filter8_4_ps_matmul<coeff2>(s[1], filter, c, tbl);
+            d[2] = filter8_4_ps_matmul<coeff2>(s[2], filter, c, tbl);
+            d[3] = filter8_4_ps_matmul<coeff2>(s[3], filter, c, tbl);
 
             store_s16x4xn<4>(dst + col, dstStride, d);
         }
@@ -598,9 +423,9 @@ void inline interp8_horiz_ps_matmul(const uint8_t *src, intptr_t srcStride,
             load_u8x16xn<3>(src + col, srcStride, s);
 
             int16x8_t d[3];
-            d[0] = filter8_8_ps_matmul(s[0], filter, c, tbl);
-            d[1] = filter8_8_ps_matmul(s[1], filter, c, tbl);
-            d[2] = filter8_8_ps_matmul(s[2], filter, c, tbl);
+            d[0] = filter8_8_ps_matmul<coeff2>(s[0], filter, c, tbl);
+            d[1] = filter8_8_ps_matmul<coeff2>(s[1], filter, c, tbl);
+            d[2] = filter8_8_ps_matmul<coeff2>(s[2], filter, c, tbl);
 
             store_s16x8xn<3>(dst + col, dstStride, d);
         }
@@ -611,9 +436,9 @@ void inline interp8_horiz_ps_matmul(const uint8_t *src, intptr_t srcStride,
             load_u8x16xn<3>(src + col, srcStride, s);
 
             int16x4_t d[3];
-            d[0] = filter8_4_ps_matmul(s[0], filter, c, tbl);
-            d[1] = filter8_4_ps_matmul(s[1], filter, c, tbl);
-            d[2] = filter8_4_ps_matmul(s[2], filter, c, tbl);
+            d[0] = filter8_4_ps_matmul<coeff2>(s[0], filter, c, tbl);
+            d[1] = filter8_4_ps_matmul<coeff2>(s[1], filter, c, tbl);
+            d[2] = filter8_4_ps_matmul<coeff2>(s[2], filter, c, tbl);
 
             store_s16x4xn<3>(dst + col, dstStride, d);
         }
@@ -626,16 +451,14 @@ void interp8_horiz_ps_i8mm(const uint8_t *src, intptr_t srcStride, int16_t *dst,
 {
     switch (coeffIdx)
     {
-    case 1:
-        return interp8_horiz_ps_matmul<1, width, height>(src, srcStride, dst,
-                                                         dstStride, isRowExt);
     case 2:
-        return interp8_horiz_ps_dotprod<width, height>(src, srcStride, dst,
-                                                       dstStride, coeffIdx,
-                                                       isRowExt);
-    case 3:
-        return interp8_horiz_ps_matmul<3, width, height>(src, srcStride, dst,
-                                                         dstStride, isRowExt);
+        return interp8_horiz_ps_matmul<true, width, height>(src, srcStride, dst,
+                                                            dstStride, coeffIdx,
+                                                            isRowExt);
+    default:
+        return interp8_horiz_ps_matmul<false, width, height>(src, srcStride, dst,
+                                                             dstStride, coeffIdx,
+                                                             isRowExt);
     }
 }
 
-- 
2.39.5 (Apple Git-154)



More information about the x265-devel mailing list