[FFmpeg-devel] [PATCH v2 2/2] swscale/aarch64: add hscale specializations

Swinney, Jonathan jswinney at amazon.com
Wed May 25 04:21:33 EEST 2022


This patch adds code to support specializations of the hscale function and adds
a specialization for filterSize == 4.

ff_hscale8to15_4_neon is a complete rewrite. Since the main bottleneck here is
loading the data from src, this data is loaded a whole block ahead and stored
to a scratch buffer on the stack, from which it is loaded again with ld4. This
arranges the data for the most efficient use of the vector instructions and
removes the need for completion adds at the end. The number of iterations of
the C loop handled per iteration of the assembly loop is increased from 4 to 8,
but because of the prefetching, there must be a special section without
prefetching when dstW < 16.

This improves speed on Graviton 2 (Neoverse N1) dramatically in the case where
filterSize 4 previously had to be padded to 8 (fs=8) to use the NEON code:

before: hscale_8_to_15__fs_8_dstW_512_neon: 1962.8
after : hscale_8_to_15__fs_4_dstW_512_neon: 1220.9

Signed-off-by: Jonathan Swinney <jswinney at amazon.com>
---
 libswscale/aarch64/hscale.S  | 172 ++++++++++++++++++++++++++++++++++-
 libswscale/aarch64/swscale.c |  40 ++++++--
 libswscale/utils.c           |   2 +-
 3 files changed, 203 insertions(+), 11 deletions(-)

diff --git a/libswscale/aarch64/hscale.S b/libswscale/aarch64/hscale.S
index da34f1cb8d..60bcd783e7 100644
--- a/libswscale/aarch64/hscale.S
+++ b/libswscale/aarch64/hscale.S
@@ -1,5 +1,7 @@
 /*
  * Copyright (c) 2016 Clément Bœsch <clement stupeflix.com>
+ * Copyright (c) 2019-2021 Sebastian Pop <spop at amazon.com>
+ * Copyright (c) 2022 Jonathan Swinney <jswinney at amazon.com>
  *
  * This file is part of FFmpeg.
  *
@@ -20,7 +22,25 @@
 
 #include "libavutil/aarch64/asm.S"
 
-function ff_hscale_8_to_15_neon, export=1
+/*
+;-----------------------------------------------------------------------------
+; horizontal line scaling
+;
+; void hscale<source_width>to<intermediate_nbits>_<filterSize>_<opt>
+;                               (SwsContext *c, int{16,32}_t *dst,
+;                                int dstW, const uint{8,16}_t *src,
+;                                const int16_t *filter,
+;                                const int32_t *filterPos, int filterSize);
+;
+; Scale one horizontal line. Input samples are either 8 or 16 bits wide
+; ($source_width can be 8, 9, 10 or 16; the difference is whether we have to
+; downscale before multiplying). Filter coefficients are 14 bits. Output is either 15 bits
+; (in int16_t) or 19 bits (in int32_t), as given in $intermediate_nbits. Each
+; output pixel is generated from $filterSize input pixels, the position of
+; the first pixel is given in filterPos[nOutputPixel].
+;----------------------------------------------------------------------------- */
+
+function ff_hscale8to15_X8_neon, export=1
         sbfiz               x7, x6, #1, #32             // filterSize*2 (*2 because int16)
 1:      ldr                 w8, [x5], #4                // filterPos[idx]
         ldr                 w0, [x5], #4                // filterPos[idx + 1]
@@ -70,3 +90,153 @@ function ff_hscale_8_to_15_neon, export=1
         b.gt                1b                          // loop until end of line
         ret
 endfunc
+
+function ff_hscale8to15_4_neon, export=1
+// x0  SwsContext *c (not used)
+// x1  int16_t *dst
+// x2  int dstW
+// x3  const uint8_t *src
+// x4  const int16_t *filter
+// x5  const int32_t *filterPos
+// x6  int filterSize (never read; this kernel hard-codes 4 taps)
+// x8-x15 registers for gathering src data
+
+// v0      madd accumulator 4S (idx 0..3)
+// v1-v4   filter values (16 bit) 8H; v6 holds the filter in the tail loop
+// v5      madd accumulator 4S (idx 4..7); holds src data in the tail loop
+// v16-v19 src values (8 bit) 8B, widened in place to 8H
+
+// dst[i] = sat16((sum_{j=0..3} src[filterPos[i] + j] * filter[4 * i + j]) >> 7), in 4 sections:
+//  1. Prefetch src data
+//  2. Interleaved prefetching src data and madd
+//  3. Complete madd
+//  4. Complete remaining iterations when dstW % 8 != 0 or dstW < 16
+
+        add                 sp, sp, #-32                // allocate a 32-byte scratch buffer (8 outputs x 4 src bytes)
+        cmp                 w2, #16                     // if dstW < 16, skip to the per-pixel wrap-up loop
+        b.lt                2f
+
+        // load 8 values from filterPos to be used as offsets into src
+        ldp                 w8, w9,  [x5]               // filterPos[idx + 0], [idx + 1]
+        ldp                 w10, w11, [x5, 8]           // filterPos[idx + 2], [idx + 3]
+        ldp                 w12, w13, [x5, 16]          // filterPos[idx + 4], [idx + 5]
+        ldp                 w14, w15, [x5, 24]          // filterPos[idx + 6], [idx + 7]
+        add                 x5, x5, #32                 // advance filterPos
+
+        // gather random access data from src into contiguous memory
+        ldr                 w8, [x3, w8, UXTW]          // src[filterPos[idx + 0]][0..3]
+        ldr                 w9, [x3, w9, UXTW]          // src[filterPos[idx + 1]][0..3]
+        ldr                 w10, [x3, w10, UXTW]        // src[filterPos[idx + 2]][0..3]
+        ldr                 w11, [x3, w11, UXTW]        // src[filterPos[idx + 3]][0..3]
+        ldr                 w12, [x3, w12, UXTW]        // src[filterPos[idx + 4]][0..3]
+        ldr                 w13, [x3, w13, UXTW]        // src[filterPos[idx + 5]][0..3]
+        ldr                 w14, [x3, w14, UXTW]        // src[filterPos[idx + 6]][0..3]
+        ldr                 w15, [x3, w15, UXTW]        // src[filterPos[idx + 7]][0..3]
+        stp                 w8, w9, [sp]                // scratch[0..7]   = { src[filterPos[idx + 0]][0..3], src[filterPos[idx + 1]][0..3] }
+        stp                 w10, w11, [sp, 8]           // scratch[8..15]  = { src[filterPos[idx + 2]][0..3], src[filterPos[idx + 3]][0..3] }
+        stp                 w12, w13, [sp, 16]          // scratch[16..23] = { src[filterPos[idx + 4]][0..3], src[filterPos[idx + 5]][0..3] }
+        stp                 w14, w15, [sp, 24]          // scratch[24..31] = { src[filterPos[idx + 6]][0..3], src[filterPos[idx + 7]][0..3] }
+
+1:
+        ld4                 {v16.8B, v17.8B, v18.8B, v19.8B}, [sp] // de-interleave scratch: v16..v19 = tap j = 0..3 for idx 0..7
+
+        // load 8 values from filterPos to be used as offsets into src
+        ldp                 w8, w9,  [x5]               // filterPos[idx + 0], [idx + 1], next iteration
+        ldp                 w10, w11, [x5, 8]           // filterPos[idx + 2], [idx + 3], next iteration
+        ldp                 w12, w13, [x5, 16]          // filterPos[idx + 4], [idx + 5], next iteration
+        ldp                 w14, w15, [x5, 24]          // filterPos[idx + 6], [idx + 7], next iteration
+
+        movi                v0.2D, #0                   // Clear madd accumulator for idx 0..3
+        movi                v5.2D, #0                   // Clear madd accumulator for idx 4..7
+
+        ld4                 {v1.8H, v2.8H, v3.8H, v4.8H}, [x4], #64 // load filter idx + 0..7, de-interleaved: v1..v4 = tap j = 0..3
+
+        add                 x5, x5, #32                 // advance filterPos
+
+        // interleaved SIMD and prefetching intended to keep ld/st and vector pipelines busy
+        uxtl                v16.8H, v16.8B              // unsigned extend long, convert src data to 16-bit
+        uxtl                v17.8H, v17.8B              // unsigned extend long, convert src data to 16-bit
+        ldr                 w8, [x3, w8, UXTW]          // src[filterPos[idx + 0]], next iteration
+        ldr                 w9, [x3, w9, UXTW]          // src[filterPos[idx + 1]], next iteration
+        uxtl                v18.8H, v18.8B              // unsigned extend long, convert src data to 16-bit
+        uxtl                v19.8H, v19.8B              // unsigned extend long, convert src data to 16-bit
+        ldr                 w10, [x3, w10, UXTW]        // src[filterPos[idx + 2]], next iteration
+        ldr                 w11, [x3, w11, UXTW]        // src[filterPos[idx + 3]], next iteration
+
+        smlal               v0.4S, v1.4H, v16.4H        // multiply accumulate inner loop j = 0, idx = 0..3
+        smlal               v0.4S, v2.4H, v17.4H        // multiply accumulate inner loop j = 1, idx = 0..3
+        ldr                 w12, [x3, w12, UXTW]        // src[filterPos[idx + 4]], next iteration
+        ldr                 w13, [x3, w13, UXTW]        // src[filterPos[idx + 5]], next iteration
+        smlal               v0.4S, v3.4H, v18.4H        // multiply accumulate inner loop j = 2, idx = 0..3
+        smlal               v0.4S, v4.4H, v19.4H        // multiply accumulate inner loop j = 3, idx = 0..3
+        ldr                 w14, [x3, w14, UXTW]        // src[filterPos[idx + 6]], next iteration
+        ldr                 w15, [x3, w15, UXTW]        // src[filterPos[idx + 7]], next iteration
+
+        smlal2              v5.4S, v1.8H, v16.8H        // multiply accumulate inner loop j = 0, idx = 4..7
+        smlal2              v5.4S, v2.8H, v17.8H        // multiply accumulate inner loop j = 1, idx = 4..7
+        stp                 w8, w9, [sp]                // scratch[0..7]   = { src[filterPos[idx + 0]][0..3], src[filterPos[idx + 1]][0..3] }
+        stp                 w10, w11, [sp, 8]           // scratch[8..15]  = { src[filterPos[idx + 2]][0..3], src[filterPos[idx + 3]][0..3] }
+        smlal2              v5.4S, v3.8H, v18.8H        // multiply accumulate inner loop j = 2, idx = 4..7
+        smlal2              v5.4S, v4.8H, v19.8H        // multiply accumulate inner loop j = 3, idx = 4..7
+        stp                 w12, w13, [sp, 16]          // scratch[16..23] = { src[filterPos[idx + 4]][0..3], src[filterPos[idx + 5]][0..3] }
+        stp                 w14, w15, [sp, 24]          // scratch[24..31] = { src[filterPos[idx + 6]][0..3], src[filterPos[idx + 7]][0..3] }
+
+        sub                 w2, w2, #8                  // dstW -= 8
+        sqshrn              v0.4H, v0.4S, #7            // >> 7 with signed saturation: 4x16-bit results for idx 0..3
+        sqshrn              v1.4H, v5.4S, #7            // >> 7 with signed saturation: 4x16-bit results for idx 4..7 (v1 is reloaded next iteration)
+        st1                 {v0.4H, v1.4H}, [x1], #16   // write to dst[idx + 0..7]
+        cmp                 w2, #16                     // loop while >= 16 outputs remain (the prefetched block + at least one more)
+        b.ge                1b
+
+        // last full iteration
+        ld4                 {v16.8B, v17.8B, v18.8B, v19.8B}, [sp] // de-interleave the src bytes prefetched by the loop above
+        ld4                 {v1.8H, v2.8H, v3.8H, v4.8H}, [x4], #64 // load filter idx + 0..7, de-interleaved: v1..v4 = tap j = 0..3
+
+        movi                v0.2D, #0                   // Clear madd accumulator for idx 0..3
+        movi                v5.2D, #0                   // Clear madd accumulator for idx 4..7
+
+        uxtl                v16.8H, v16.8B              // unsigned extend long, convert src data to 16-bit
+        uxtl                v17.8H, v17.8B              // unsigned extend long, convert src data to 16-bit
+        uxtl                v18.8H, v18.8B              // unsigned extend long, convert src data to 16-bit
+        uxtl                v19.8H, v19.8B              // unsigned extend long, convert src data to 16-bit
+
+        smlal               v0.4S, v1.4H, v16.4H        // multiply accumulate inner loop j = 0, idx = 0..3
+        smlal               v0.4S, v2.4H, v17.4H        // multiply accumulate inner loop j = 1, idx = 0..3
+        smlal               v0.4S, v3.4H, v18.4H        // multiply accumulate inner loop j = 2, idx = 0..3
+        smlal               v0.4S, v4.4H, v19.4H        // multiply accumulate inner loop j = 3, idx = 0..3
+
+        smlal2              v5.4S, v1.8H, v16.8H        // multiply accumulate inner loop j = 0, idx = 4..7
+        smlal2              v5.4S, v2.8H, v17.8H        // multiply accumulate inner loop j = 1, idx = 4..7
+        smlal2              v5.4S, v3.8H, v18.8H        // multiply accumulate inner loop j = 2, idx = 4..7
+        smlal2              v5.4S, v4.8H, v19.8H        // multiply accumulate inner loop j = 3, idx = 4..7
+
+        subs                w2, w2, #8                  // dstW -= 8 (flags unused; cbnz below tests w2 directly)
+        sqshrn              v0.4H, v0.4S, #7            // >> 7 with signed saturation: 4x16-bit results for idx 0..3
+        sqshrn              v1.4H, v5.4S, #7            // >> 7 with signed saturation: 4x16-bit results for idx 4..7
+        st1                 {v0.4H, v1.4H}, [x1], #16   // write to dst[idx + 0..7]
+
+        cbnz                w2, 2f                      // if 1..7 outputs remain, jump to the wrap-up section
+
+        add                 sp, sp, #32                 // clean up stack
+        ret
+
+        // finish up when dstW % 8 != 0 or dstW < 16
+2:
+        // load src
+        ldr                 w8, [x5], #4                // filterPos[i]
+        add                 x9, x3, w8, UXTW            // calculate the address for src load
+        ld1                 {v5.S}[0], [x9]             // src[filterPos[i] + 0..3]
+        // load filter
+        ld1                 {v6.4H}, [x4], #8           // filter[filterSize * i + 0..3]
+
+        uxtl                v5.8H, v5.8B                // unsigned extend long, convert src data to 16-bit (only lanes 0..3 are used)
+        smull               v0.4S, v5.4H, v6.4H         // 4 iterations of src[...] * filter[...]
+        addv                s0, v0.4S                   // add up products of src and filter values
+        sqshrn              h0, s0, #7                  // >> 7 with signed saturation to the 16-bit final value
+        st1                 {v0.H}[0], [x1], #2         // dst[i] = ...
+        sub                 w2, w2, #1                  // dstW--
+        cbnz                w2, 2b                      // loop until dstW == 0; NOTE: entering with dstW <= 0 would not terminate
+
+        add                 sp, sp, #32                 // clean up stack
+        ret
+endfunc
diff --git a/libswscale/aarch64/swscale.c b/libswscale/aarch64/swscale.c
index 09d0a7130e..583e385825 100644
--- a/libswscale/aarch64/swscale.c
+++ b/libswscale/aarch64/swscale.c
@@ -22,25 +22,47 @@
 #include "libswscale/swscale_internal.h"
 #include "libavutil/aarch64/cpu.h"
 
-void ff_hscale_8_to_15_neon(SwsContext *c, int16_t *dst, int dstW,
-                            const uint8_t *src, const int16_t *filter,
-                            const int32_t *filterPos, int filterSize);
+/* Declares ff_hscale<from_bpc>to<to_bpc>_<filter_n>_<opt>(), defined in hscale.S. */
+#define SCALE_FUNC(filter_n, from_bpc, to_bpc, opt) \
+void ff_hscale ## from_bpc ## to ## to_bpc ## _ ## filter_n ## _ ## opt( \
+        SwsContext *c, int16_t *dst, int dstW, \
+        const uint8_t *src, const int16_t *filter, \
+        const int32_t *filterPos, int filterSize)
+#define SCALE_FUNCS(filter_n, opt) \
+    SCALE_FUNC(filter_n,  8, 15, opt)
+/* hscale.S implements only the fs == 4 and the generic fs % 8 == 0 ("X8") kernels. */
+#define ALL_SCALE_FUNCS(opt) \
+    SCALE_FUNCS(4, opt); \
+    SCALE_FUNCS(X8, opt)
+
+ALL_SCALE_FUNCS(neon);
 
 void ff_yuv2planeX_8_neon(const int16_t *filter, int filterSize,
                           const int16_t **src, uint8_t *dest, int dstW,
                           const uint8_t *dither, int offset);
 
+/* Use the 8-bit -> 15-bit NEON kernel when the formats match; reads `c` from the caller's scope. */
+#define ASSIGN_SCALE_FUNC2(hscalefn, filtersize, opt) do {              \
+    if (c->srcBpc == 8 && c->dstBpc <= 14)                              \
+        (hscalefn) = ff_hscale8to15_ ## filtersize ## _ ## opt;         \
+} while (0)
+/* Only fs == 4 and fs % 8 == 0 have NEON kernels; other sizes (e.g. 12) leave hscalefn as is. */
+#define ASSIGN_SCALE_FUNC(hscalefn, filtersize, opt) do {               \
+    switch (filtersize) {                                               \
+    case 4:  ASSIGN_SCALE_FUNC2(hscalefn, 4, opt); break;               \
+    default: if ((filtersize) % 8 == 0)                                 \
+                 ASSIGN_SCALE_FUNC2(hscalefn, X8, opt);                 \
+             break;                                                     \
+    }                                                                   \
+} while (0)
+
 av_cold void ff_sws_init_swscale_aarch64(SwsContext *c)
 {
     int cpu_flags = av_get_cpu_flags();
 
     if (have_neon(cpu_flags)) {
-        if (c->srcBpc == 8 && c->dstBpc <= 14 &&
-            (c->hLumFilterSize % 8) == 0 &&
-            (c->hChrFilterSize % 8) == 0)
-        {
-            c->hyScale = c->hcScale = ff_hscale_8_to_15_neon;
-        }
+        ASSIGN_SCALE_FUNC(c->hyScale, c->hLumFilterSize, neon);
+        ASSIGN_SCALE_FUNC(c->hcScale, c->hChrFilterSize, neon);
         if (c->dstBpc == 8) {
             c->yuv2planeX = ff_yuv2planeX_8_neon;
         }
diff --git a/libswscale/utils.c b/libswscale/utils.c
index ffa130524a..105781c4f4 100644
--- a/libswscale/utils.c
+++ b/libswscale/utils.c
@@ -1820,7 +1820,7 @@ av_cold int sws_init_context(SwsContext *c, SwsFilter *srcFilter,
         {
             const int filterAlign = X86_MMX(cpu_flags)     ? 4 :
                                     PPC_ALTIVEC(cpu_flags) ? 8 :
-                                    have_neon(cpu_flags)   ? 8 : 1;
+                                    have_neon(cpu_flags)   ? 4 : 1;
 
             if ((ret = initFilter(&c->hLumFilter, &c->hLumFilterPos,
                            &c->hLumFilterSize, c->lumXInc,
-- 
2.32.0



More information about the ffmpeg-devel mailing list