[FFmpeg-devel] [PATCH 1/5] avcodec/vvc_mc: split the SAD dsp prototype into one function per block width

James Almer jamrial at gmail.com
Thu May 23 15:27:12 EEST 2024

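Instead of a single SAD function that branches on block_w at runtime, expose
an array of five function pointers, one per block width (8, 16, 32, 64 and
128), and select the entry in C from the block width. The x86 AVX2 code is
split into vvc_sad_8 and vvc_sad_16, the latter handling widths 16 and above.
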
Signed-off-by: James Almer <jamrial at gmail.com>
---
 libavcodec/vvc/dsp.h             |  2 +-
 libavcodec/vvc/inter.c           |  6 ++++--
 libavcodec/vvc/inter_template.c  |  6 +++++-
 libavcodec/x86/vvc/vvc_sad.asm   | 32 ++++++++++++++++++++++++++------
 libavcodec/x86/vvc/vvcdsp_init.c | 22 +++++++++++++++++-----
 tests/checkasm/vvc_mc.c          |  3 ++-
 6 files changed, 55 insertions(+), 16 deletions(-)

diff --git a/libavcodec/vvc/dsp.h b/libavcodec/vvc/dsp.h
index 1f14096c41..55c4c81f53 100644
--- a/libavcodec/vvc/dsp.h
+++ b/libavcodec/vvc/dsp.h
@@ -99,7 +99,7 @@ typedef struct VVCInterDSPContext {
 
     void (*apply_bdof)(uint8_t *dst, ptrdiff_t dst_stride, int16_t *src0, int16_t *src1, int block_w, int block_h);
 
-    int (*sad)(const int16_t *src0, const int16_t *src1, int dx, int dy, int block_w, int block_h);
+    int (*sad[5])(const int16_t *src0, const int16_t *src1, int dx, int dy, int block_w, int block_h);
     void (*dmvr[2][2])(int16_t *dst, const uint8_t *src, ptrdiff_t src_stride, int height,
         intptr_t mx, intptr_t my, int width);
 } VVCInterDSPContext;
diff --git a/libavcodec/vvc/inter.c b/libavcodec/vvc/inter.c
index e1011b4fa1..0214e46634 100644
--- a/libavcodec/vvc/inter.c
+++ b/libavcodec/vvc/inter.c
@@ -740,6 +740,8 @@ static void dmvr_mv_refine(VVCLocalContext *lc, MvField *mvf, MvField *orig_mv,
     const AVFrame *ref0, const AVFrame *ref1, const int x_off, const int y_off, const int block_w, const int block_h)
 {
     const VVCFrameContext *fc   = lc->fc;
+    static const uint8_t sad_tab[16] = { 0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4 };
+    const int tab               = sad_tab[(FFALIGN(block_w, 8) >> 3) - 1];
     const int sr_range          = 2;
     const AVFrame *ref[]        = { ref0, ref1 };
     int16_t *tmp[]              = { lc->tmp, lc->tmp1 };
@@ -763,7 +765,7 @@ static void dmvr_mv_refine(VVCLocalContext *lc, MvField *mvf, MvField *orig_mv,
         fc->vvcdsp.inter.dmvr[!!my][!!mx](tmp[i], src, src_stride, pred_h, mx, my, pred_w);
     }
 
-    min_sad = fc->vvcdsp.inter.sad(tmp[L0], tmp[L1], dx, dy, block_w, block_h);
+    min_sad = fc->vvcdsp.inter.sad[tab](tmp[L0], tmp[L1], dx, dy, block_w, block_h);
     min_sad -= min_sad >> 2;
     sad[dy][dx] = min_sad;
 
@@ -773,7 +775,7 @@ static void dmvr_mv_refine(VVCLocalContext *lc, MvField *mvf, MvField *orig_mv,
         for (dy = 0; dy < SAD_ARRAY_SIZE; dy++) {
             for (dx = 0; dx < SAD_ARRAY_SIZE; dx++) {
                 if (dx != sr_range || dy != sr_range) {
-                    sad[dy][dx] = fc->vvcdsp.inter.sad(lc->tmp, lc->tmp1, dx, dy, block_w, block_h);
+                    sad[dy][dx] = fc->vvcdsp.inter.sad[tab](lc->tmp, lc->tmp1, dx, dy, block_w, block_h);
                     if (sad[dy][dx] < min_sad) {
                         min_sad = sad[dy][dx];
                         min_dx = dx;
diff --git a/libavcodec/vvc/inter_template.c b/libavcodec/vvc/inter_template.c
index a8068f4ba8..34485321d3 100644
--- a/libavcodec/vvc/inter_template.c
+++ b/libavcodec/vvc/inter_template.c
@@ -626,7 +626,11 @@ static void FUNC(ff_vvc_inter_dsp_init)(VVCInterDSPContext *const inter)
     inter->apply_prof_uni_w     = FUNC(apply_prof_uni_w);
     inter->apply_bdof           = FUNC(apply_bdof);
     inter->prof_grad_filter     = FUNC(prof_grad_filter);
-    inter->sad                  = vvc_sad;
+    inter->sad[0]               =
+    inter->sad[1]               =
+    inter->sad[2]               =
+    inter->sad[3]               =
+    inter->sad[4]               = vvc_sad;
 }
 
 #undef FUNCS
diff --git a/libavcodec/x86/vvc/vvc_sad.asm b/libavcodec/x86/vvc/vvc_sad.asm
index b468d89ac2..a20818530f 100644
--- a/libavcodec/x86/vvc/vvc_sad.asm
+++ b/libavcodec/x86/vvc/vvc_sad.asm
@@ -51,7 +51,7 @@ SECTION .text
 
 INIT_YMM avx2
 
-cglobal vvc_sad, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_idx
+cglobal vvc_sad_8, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_idx
     movsxdifnidn    dxq, dxd
     movsxdifnidn    dyq, dyd
 
@@ -76,10 +76,6 @@ cglobal vvc_sad, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_
     pxor               m3, m3
     vpbroadcastd       m4, [pw_1]
 
-    cmp          block_wd, 16
-    jge    vvc_sad_16_128
-
-    vvc_sad_8:
         .loop_height:
         movu              xm0, [src1q]
         vinserti128        m0, m0, [src1q + MAX_PB_SIZE * ROWS * 2], 1
@@ -100,7 +96,31 @@ cglobal vvc_sad, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_
         movd          eax, xm0
     RET
 
-    vvc_sad_16_128:
+cglobal vvc_sad_16, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_idx
+    movsxdifnidn    dxq, dxd
+    movsxdifnidn    dyq, dyd
+
+    sub             dxq, 2
+    sub             dyq, 2
+
+    mov             off1q, 2
+    mov             off2q, 2
+
+    add             off1q, dyq
+    sub             off2q, dyq
+
+    shl             off1q, 7
+    shl             off2q, 7
+
+    add             off1q, dxq
+    sub             off2q, dxq
+
+    lea             src1q, [src1q + off1q * 2 + 2 * 2]
+    lea             src2q, [src2q + off2q * 2 + 2 * 2]
+
+    pxor               m3, m3
+    vpbroadcastd       m4, [pw_1]
+
         sar      block_wd, 4
         .loop_height:
         mov         off1q, src1q
diff --git a/libavcodec/x86/vvc/vvcdsp_init.c b/libavcodec/x86/vvc/vvcdsp_init.c
index 4b4a2aa937..bd60963432 100644
--- a/libavcodec/x86/vvc/vvcdsp_init.c
+++ b/libavcodec/x86/vvc/vvcdsp_init.c
@@ -312,8 +312,20 @@ ALF_FUNCS(16, 12, avx2)
     c->alf.classify       = ff_vvc_alf_classify_##bd##_avx2;         \
 } while (0)
 
-int ff_vvc_sad_avx2(const int16_t *src0, const int16_t *src1, int dx, int dy, int block_w, int block_h);
-#define SAD_INIT() c->inter.sad = ff_vvc_sad_avx2
+#define SAD_PROTOTYPE(w, opt)                                        \
+int bf(ff_vvc_sad, w, opt)(const int16_t *src0, const int16_t *src1, \
+                           int dx, int dy, int block_w, int block_h) \
+
+SAD_PROTOTYPE(8,   avx2);
+SAD_PROTOTYPE(16,  avx2);
+
+#define SAD_INIT(opt) do {                   \
+    c->inter.sad[0] = ff_vvc_sad_8_##opt;    \
+    c->inter.sad[1] =                        \
+    c->inter.sad[2] =                        \
+    c->inter.sad[3] =                        \
+    c->inter.sad[4] = ff_vvc_sad_16_##opt;   \
+} while (0)
 #endif
 
 void ff_vvc_dsp_init_x86(VVCDSPContext *const c, const int bd)
@@ -330,7 +342,7 @@ void ff_vvc_dsp_init_x86(VVCDSPContext *const c, const int bd)
             ALF_INIT(8);
             AVG_INIT(8, avx2);
             MC_LINKS_AVX2(8);
-            SAD_INIT();
+            SAD_INIT(avx2);
         }
         break;
     case 10:
@@ -342,7 +354,7 @@ void ff_vvc_dsp_init_x86(VVCDSPContext *const c, const int bd)
             AVG_INIT(10, avx2);
             MC_LINKS_AVX2(10);
             MC_LINKS_16BPC_AVX2(10);
-            SAD_INIT();
+            SAD_INIT(avx2);
         }
         break;
     case 12:
@@ -354,7 +366,7 @@ void ff_vvc_dsp_init_x86(VVCDSPContext *const c, const int bd)
             AVG_INIT(12, avx2);
             MC_LINKS_AVX2(12);
             MC_LINKS_16BPC_AVX2(12);
-            SAD_INIT();
+            SAD_INIT(avx2);
         }
         break;
     default:
diff --git a/tests/checkasm/vvc_mc.c b/tests/checkasm/vvc_mc.c
index 1e889e2cff..deae1014d2 100644
--- a/tests/checkasm/vvc_mc.c
+++ b/tests/checkasm/vvc_mc.c
@@ -327,6 +327,7 @@ static void check_avg(void)
 static void check_vvc_sad(void)
 {
     const int bit_depth = 10;
+    static const uint8_t sad_tab[16] = { 0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4 };
     VVCDSPContext c;
     LOCAL_ALIGNED_32(uint16_t, src0, [MAX_CTU_SIZE * MAX_CTU_SIZE * 4]);
     LOCAL_ALIGNED_32(uint16_t, src1, [MAX_CTU_SIZE * MAX_CTU_SIZE * 4]);
@@ -341,7 +342,7 @@ static void check_vvc_sad(void)
         for (int w = 8; w <= MAX_CTU_SIZE; w *= 2) {
             for(int offy = 0; offy <= 4; offy++) {
                 for(int offx = 0; offx <= 4; offx++) {
-                    if(check_func(c.inter.sad, "sad_%dx%d", w, h)) {
+                    if(check_func(c.inter.sad[sad_tab[(w >> 3) - 1]], "sad_%dx%d", w, h)) {
                         int result0;
                         int result1;
 
-- 
2.45.1



More information about the ffmpeg-devel mailing list