[x265] [PATCH 2/2] AArch64: Add SBD and HBD Neon implementation of weight_pp

Micro Daryl Robles microdaryl.robles at arm.com
Mon Apr 7 10:57:30 UTC 2025


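Add Neon intrinsic implementation of weight_pp that works for both SBD
and HBD. Remove the Neon asm for SBD as the intrinsics implementation
improves performance for the more general case where CTZ(w0) < shift
(here "shift" is the weighting shift before the IF_INTERNAL_PREC -
X265_DEPTH correction is added, i.e. "shift - correction" in the code).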

Update the test and speedup benchmarks to test both conditions where
CTZ(w0) >= shift and CTZ(w0) < shift.

Relative performance compared to Neon asm [SBD]:
 (w0 = 64)
 Neoverse N1: 1.19x
 Neoverse N2: 1.00x
 Neoverse V1: 1.10x
 Neoverse V2: 1.01x
 (w0 = 127)
 Neoverse N1: 3.05x
 Neoverse N2: 3.63x
 Neoverse V1: 3.25x
 Neoverse V2: 3.58x

Relative performance compared to scalar C [HBD]:
 Neoverse N1: 1.53x
 Neoverse N2: 2.03x
 Neoverse V1: 1.71x
 Neoverse V2: 1.65x
---
 source/common/aarch64/asm-primitives.cpp |   1 -
 source/common/aarch64/fun-decls.h        |   1 -
 source/common/aarch64/pixel-prim.cpp     | 129 +++++++++++++++++++++++
 source/common/aarch64/pixel-util.S       | 102 ------------------
 source/test/pixelharness.cpp             |  42 +++++---
 5 files changed, 158 insertions(+), 117 deletions(-)

diff --git a/source/common/aarch64/asm-primitives.cpp b/source/common/aarch64/asm-primitives.cpp
index 621dbf334..f96d7a426 100644
--- a/source/common/aarch64/asm-primitives.cpp
+++ b/source/common/aarch64/asm-primitives.cpp
@@ -732,7 +732,6 @@ void setupNeonPrimitives(EncoderPrimitives &p)
     // psy_cost_pp
     p.cu[BLOCK_4x4].psy_cost_pp = PFX(psyCost_4x4_neon);
 
-    p.weight_pp = PFX(weight_pp_neon);
 #if !defined(__APPLE__)
     p.scanPosLast = PFX(scanPosLast_neon);
 #endif
diff --git a/source/common/aarch64/fun-decls.h b/source/common/aarch64/fun-decls.h
index 22fefb398..12383b573 100644
--- a/source/common/aarch64/fun-decls.h
+++ b/source/common/aarch64/fun-decls.h
@@ -226,7 +226,6 @@ void PFX(ssim_4x4x2_core_neon(const pixel* pix1, intptr_t stride1, const pixel*
 
 int PFX(psyCost_4x4_neon)(const pixel* source, intptr_t sstride, const pixel* recon, intptr_t rstride);
 int PFX(psyCost_8x8_neon)(const pixel* source, intptr_t sstride, const pixel* recon, intptr_t rstride);
-void PFX(weight_pp_neon)(const pixel* src, pixel* dst, intptr_t stride, int width, int height, int w0, int round, int shift, int offset);
 void PFX(weight_sp_neon)(const int16_t* src, pixel* dst, intptr_t srcStride, intptr_t dstStride, int width, int height, int w0, int round, int shift, int offset);
 int PFX(scanPosLast_neon)(const uint16_t *scan, const coeff_t *coeff, uint16_t *coeffSign, uint16_t *coeffFlag, uint8_t *coeffNum, int numSig, const uint16_t* scanCG4x4, const int trSize);
 uint32_t PFX(costCoeffNxN_neon)(const uint16_t *scan, const coeff_t *coeff, intptr_t trSize, uint16_t *absCoeff, const uint8_t *tabSigCtx, uint32_t scanFlagMask, uint8_t *baseCtx, int offset, int scanPosSigOff, int subPosBase);
diff --git a/source/common/aarch64/pixel-prim.cpp b/source/common/aarch64/pixel-prim.cpp
index 4a7831428..7ee807d3b 100644
--- a/source/common/aarch64/pixel-prim.cpp
+++ b/source/common/aarch64/pixel-prim.cpp
@@ -1141,6 +1141,133 @@ void planecopy_cp_neon(const uint8_t *src, intptr_t srcStride, pixel *dst,
     while (--height != 0);
 }
 
+void weight_pp_neon(const pixel *src, pixel *dst, intptr_t stride, int width, int height,
+                    int w0, int round, int shift, int offset)
+{
+    const int correction = IF_INTERNAL_PREC - X265_DEPTH;
+
+    X265_CHECK(height >= 1, "height length error\n");
+    X265_CHECK(width >= 16, "width length error\n");
+    X265_CHECK(!(width & 15), "width alignment error\n");
+    X265_CHECK(w0 >= 0, "w0 should be min 0\n");
+    X265_CHECK(w0 < 128, "w0 should be max 127\n");
+    X265_CHECK(shift >= correction, "shift must include factor correction\n");
+    X265_CHECK((round & ((1 << correction) - 1)) == 0,
+               "round must include factor correction\n");
+
+    (void)round;
+
+#if HIGH_BIT_DEPTH
+    int32x4_t corrected_shift = vdupq_n_s32(correction - shift);
+
+    do
+    {
+        int w = 0;
+        do
+        {
+            int16x8_t s0 = vreinterpretq_s16_u16(vld1q_u16(src + w + 0));
+            int16x8_t s1 = vreinterpretq_s16_u16(vld1q_u16(src + w + 8));
+            int32x4_t weighted_s0_lo = vmull_n_s16(vget_low_s16(s0), w0);
+            int32x4_t weighted_s0_hi = vmull_n_s16(vget_high_s16(s0), w0);
+            int32x4_t weighted_s1_lo = vmull_n_s16(vget_low_s16(s1), w0);
+            int32x4_t weighted_s1_hi = vmull_n_s16(vget_high_s16(s1), w0);
+            weighted_s0_lo = vrshlq_s32(weighted_s0_lo, corrected_shift);
+            weighted_s0_hi = vrshlq_s32(weighted_s0_hi, corrected_shift);
+            weighted_s1_lo = vrshlq_s32(weighted_s1_lo, corrected_shift);
+            weighted_s1_hi = vrshlq_s32(weighted_s1_hi, corrected_shift);
+            weighted_s0_lo = vaddq_s32(weighted_s0_lo, vdupq_n_s32(offset));
+            weighted_s0_hi = vaddq_s32(weighted_s0_hi, vdupq_n_s32(offset));
+            weighted_s1_lo = vaddq_s32(weighted_s1_lo, vdupq_n_s32(offset));
+            weighted_s1_hi = vaddq_s32(weighted_s1_hi, vdupq_n_s32(offset));
+            uint16x4_t t0_lo = vqmovun_s32(weighted_s0_lo);
+            uint16x4_t t0_hi = vqmovun_s32(weighted_s0_hi);
+            uint16x4_t t1_lo = vqmovun_s32(weighted_s1_lo);
+            uint16x4_t t1_hi = vqmovun_s32(weighted_s1_hi);
+            uint16x8_t d0 = vminq_u16(vcombine_u16(t0_lo, t0_hi), vdupq_n_u16(PIXEL_MAX));
+            uint16x8_t d1 = vminq_u16(vcombine_u16(t1_lo, t1_hi), vdupq_n_u16(PIXEL_MAX));
+
+            vst1q_u16(dst + w + 0, d0);
+            vst1q_u16(dst + w + 8, d1);
+            w += 16;
+        }
+        while (w != width);
+
+        src += stride;
+        dst += stride;
+    }
+    while (--height != 0);
+
+#else
+    // Re-arrange the shift operations.
+    // Then, hoist the right shift out of the loop if CTZ(w0) >= shift - correction.
+    // Orig: (((src[x] << correction) * w0 + round) >> shift) + offset.
+    // New: (src[x] * (w0 >> (shift - correction))) + (round >> shift) + offset.
+    // (round >> shift) is always zero since round = 1 << (shift - 1).
+
+    unsigned long id;
+    CTZ(id, w0);
+
+    if ((int)id >= shift - correction)
+    {
+        w0 >>= shift - correction;
+
+        do
+        {
+            int w = 0;
+            do
+            {
+                uint8x16_t s = vld1q_u8(src + w);
+                int16x8_t weighted_s0 = vreinterpretq_s16_u16(
+                    vmlal_u8(vdupq_n_u16(offset), vget_low_u8(s), vdup_n_u8(w0)));
+                int16x8_t weighted_s1 = vreinterpretq_s16_u16(
+                    vmlal_u8(vdupq_n_u16(offset), vget_high_u8(s), vdup_n_u8(w0)));
+                uint8x8_t d0 = vqmovun_s16(weighted_s0);
+                uint8x8_t d1 = vqmovun_s16(weighted_s1);
+
+                vst1q_u8(dst + w, vcombine_u8(d0, d1));
+                w += 16;
+            }
+            while (w != width);
+
+            src += stride;
+            dst += stride;
+        }
+        while (--height != 0);
+    }
+    else // Keep rounding shifts within the loop.
+    {
+        int16x8_t corrected_shift = vdupq_n_s16(correction - shift);
+
+        do
+        {
+            int w = 0;
+            do
+            {
+                uint8x16_t s = vld1q_u8(src + w);
+                int16x8_t weighted_s0 =
+                    vreinterpretq_s16_u16(vmull_u8(vget_low_u8(s), vdup_n_u8(w0)));
+                int16x8_t weighted_s1 =
+                    vreinterpretq_s16_u16(vmull_u8(vget_high_u8(s), vdup_n_u8(w0)));
+                weighted_s0 = vrshlq_s16(weighted_s0, corrected_shift);
+                weighted_s1 = vrshlq_s16(weighted_s1, corrected_shift);
+                weighted_s0 = vaddq_s16(weighted_s0, vdupq_n_s16(offset));
+                weighted_s1 = vaddq_s16(weighted_s1, vdupq_n_s16(offset));
+                uint8x8_t d0 = vqmovun_s16(weighted_s0);
+                uint8x8_t d1 = vqmovun_s16(weighted_s1);
+
+                vst1q_u8(dst + w, vcombine_u8(d0, d1));
+                w += 16;
+            }
+            while (w != width);
+
+            src += stride;
+            dst += stride;
+        }
+        while (--height != 0);
+    }
+#endif
+}
+
 template<int lx, int ly>
 void pixelavg_pp_neon(pixel *dst, intptr_t dstride, const pixel *src0, intptr_t sstride0, const pixel *src1,
                       intptr_t sstride1, int)
@@ -1765,6 +1892,8 @@ void setupPixelPrimitives_neon(EncoderPrimitives &p)
     p.chroma[X265_CSP_I422].cu[BLOCK_32x32].sa8d = sa8d16<16, 32>;
     p.chroma[X265_CSP_I422].cu[BLOCK_64x64].sa8d = sa8d16<32, 64>;
 
+    p.weight_pp = weight_pp_neon;
+
     p.planecopy_cp = planecopy_cp_neon;
 }
 
diff --git a/source/common/aarch64/pixel-util.S b/source/common/aarch64/pixel-util.S
index 72f8bbc8b..161c9a210 100644
--- a/source/common/aarch64/pixel-util.S
+++ b/source/common/aarch64/pixel-util.S
@@ -2059,108 +2059,6 @@ function PFX(normFact64_neon)
     ret
 endfunc
 
-// void weight_pp_c(const pixel* src, pixel* dst, intptr_t stride, int width, int height, int w0, int round, int shift, int offset)
-function PFX(weight_pp_neon)
-    sub             x2, x2, x3
-    ldr             w9, [sp]              // offset
-    lsl             w5, w5, #6            // w0 << correction
-
-    // count trailing zeros in w5 and compare against shift right amount.
-    rbit            w10, w5
-    clz             w10, w10
-    cmp             w10, w7
-    b.lt            .unfoldedShift
-
-    // shift right only removes trailing zeros: hoist LSR out of the loop.
-    lsr             w10, w5, w7           // w0 << correction >> shift
-    dup             v25.16b, w10
-    lsr             w6, w6, w7            // round >> shift
-    add             w6, w6, w9            // round >> shift + offset
-    dup             v26.8h, w6
-
-    // Check arithmetic range.
-    mov             w11, #255
-    madd            w11, w11, w10, w6
-    add             w11, w11, w9
-    lsr             w11, w11, #16
-    cbnz            w11, .widenTo32Bit
-
-    // 16-bit arithmetic is enough.
-.LoopHpp:
-    mov             x12, x3
-.LoopWpp:
-    ldr             q0, [x0], #16
-    sub             x12, x12, #16
-    umull           v1.8h, v0.8b, v25.8b  // val *= w0 << correction >> shift
-    umull2          v2.8h, v0.16b, v25.16b
-    add             v1.8h, v1.8h, v26.8h  // val += round >> shift + offset
-    add             v2.8h, v2.8h, v26.8h
-    sqxtun          v0.8b, v1.8h          // val = x265_clip(val)
-    sqxtun2         v0.16b, v2.8h
-    str             q0, [x1], #16
-    cbnz            x12, .LoopWpp
-    add             x1, x1, x2
-    add             x0, x0, x2
-    sub             x4, x4, #1
-    cbnz            x4, .LoopHpp
-    ret
-
-    // 32-bit arithmetic is needed.
-.widenTo32Bit:
-.LoopHpp32:
-    mov             x12, x3
-.LoopWpp32:
-    ldr             d0, [x0], #8
-    sub             x12, x12, #8
-    uxtl            v0.8h, v0.8b
-    umull           v1.4s, v0.4h, v25.4h  // val *= w0 << correction >> shift
-    umull2          v2.4s, v0.8h, v25.8h
-    add             v1.4s, v1.4s, v26.4s  // val += round >> shift + offset
-    add             v2.4s, v2.4s, v26.4s
-    sqxtn           v0.4h, v1.4s          // val = x265_clip(val)
-    sqxtn2          v0.8h, v2.4s
-    sqxtun          v0.8b, v0.8h
-    str             d0, [x1], #8
-    cbnz            x12, .LoopWpp32
-    add             x1, x1, x2
-    add             x0, x0, x2
-    sub             x4, x4, #1
-    cbnz            x4, .LoopHpp32
-    ret
-
-    // The shift right cannot be moved out of the loop.
-.unfoldedShift:
-    dup             v25.8h, w5            // w0 << correction
-    dup             v26.4s, w6            // round
-    neg             w7, w7                // -shift
-    dup             v27.4s, w7
-    dup             v29.4s, w9            // offset
-.LoopHppUS:
-    mov             x12, x3
-.LoopWppUS:
-    ldr             d0, [x0], #8
-    sub             x12, x12, #8
-    uxtl            v0.8h, v0.8b
-    umull           v1.4s, v0.4h, v25.4h  // val *= w0
-    umull2          v2.4s, v0.8h, v25.8h
-    add             v1.4s, v1.4s, v26.4s  // val += round
-    add             v2.4s, v2.4s, v26.4s
-    sshl            v1.4s, v1.4s, v27.4s  // val >>= shift
-    sshl            v2.4s, v2.4s, v27.4s
-    add             v1.4s, v1.4s, v29.4s  // val += offset
-    add             v2.4s, v2.4s, v29.4s
-    sqxtn           v0.4h, v1.4s          // val = x265_clip(val)
-    sqxtn2          v0.8h, v2.4s
-    sqxtun          v0.8b, v0.8h
-    str             d0, [x1], #8
-    cbnz            x12, .LoopWppUS
-    add             x1, x1, x2
-    add             x0, x0, x2
-    sub             x4, x4, #1
-    cbnz            x4, .LoopHppUS
-    ret
-endfunc
-
 // int scanPosLast(
 //     const uint16_t *scan,      // x0
 //     const coeff_t *coeff,      // x1
diff --git a/source/test/pixelharness.cpp b/source/test/pixelharness.cpp
index f46f9ae3d..beb1e80bd 100644
--- a/source/test/pixelharness.cpp
+++ b/source/test/pixelharness.cpp
@@ -338,26 +338,33 @@ bool PixelHarness::check_weightp(weightp_pp_t ref, weightp_pp_t opt)
     if (cpuid & X265_CPU_AVX512)
         width = 32 * (rand() % 2 + 1);
     int height = 8;
-    int w0 = rand() % 128;
-    int shift = rand() % 8; // maximum is 7, see setFromWeightAndOffset()
+    int shift = (rand() % 6) + 1;
+    // Make CTZ(w0) >= shift; max of 126.
+    int w0 = (rand() % ((1 << (7 - shift)) - 1) + 1) << shift;
     int round = shift ? (1 << (shift - 1)) : 0;
     int offset = (rand() % 256) - 128;
     intptr_t stride = 64;
     const int correction = (IF_INTERNAL_PREC - X265_DEPTH);
-    for (int i = 0; i < ITERS; i++)
+
+    for (int k = 0; k < 2; k++)
     {
-        int index = i % TEST_CASES;
-        checked(opt, pixel_test_buff[index] + j, opt_dest, stride, width, height, w0, round << correction, shift + correction, offset);
-        ref(pixel_test_buff[index] + j, ref_dest, stride, width, height, w0, round << correction, shift + correction, offset);
+        w0 += k; // 1st: CTZ(w0) >= shift; 2nd: CTZ(w0) < shift
 
-        if (memcmp(ref_dest, opt_dest, 64 * 64 * sizeof(pixel)))
+        for (int i = 0; i < ITERS; i++)
         {
+            int index = i % TEST_CASES;
             checked(opt, pixel_test_buff[index] + j, opt_dest, stride, width, height, w0, round << correction, shift + correction, offset);
-            return false;
-        }
+            ref(pixel_test_buff[index] + j, ref_dest, stride, width, height, w0, round << correction, shift + correction, offset);
 
-        reportfail();
-        j += INCR;
+            if (memcmp(ref_dest, opt_dest, 64 * 64 * sizeof(pixel)))
+            {
+                checked(opt, pixel_test_buff[index] + j, opt_dest, stride, width, height, w0, round << correction, shift + correction, offset);
+                return false;
+            }
+
+            reportfail();
+            j += INCR;
+        }
     }
 
     return true;
@@ -3508,8 +3515,17 @@ void PixelHarness::measureSpeed(const EncoderPrimitives& ref, const EncoderPrimi
 
     if (opt.weight_pp)
     {
-        HEADER0("weight_pp");
-        REPORT_SPEEDUP(opt.weight_pp, ref.weight_pp, pbuf1, pbuf2, 64, 32, 32, 128, 1 << 9, 10, 100);
+        int w0[2] = {64, 127}; // max: 127. 1: CTZ(64) >= shift. 2: CTZ(127) < shift.
+        int shift = 6;
+        int round = 1 << (shift - 1);
+        int offset = 100; // -128 to 127
+        const int correction = IF_INTERNAL_PREC - X265_DEPTH;
+        for (int i = 0; i < 2; i++)
+        {
+            HEADER("weight_pp[w0=%d]", w0[i]);
+            REPORT_SPEEDUP(opt.weight_pp, ref.weight_pp, pbuf1, pbuf2, 64, 32, 32, w0[i],
+                           round << correction, shift + correction, offset);
+        }
     }
 
     if (opt.weight_sp)
-- 
2.34.1

-------------- next part --------------
A non-text attachment was scrubbed...
Name: 0002-AArch64-Add-SBD-and-HBD-Neon-implementation-of-weigh.patch
Type: text/x-diff
Size: 16239 bytes
Desc: not available
URL: <http://mailman.videolan.org/pipermail/x265-devel/attachments/20250407/f2d7cc83/attachment-0001.patch>


More information about the x265-devel mailing list