[FFmpeg-devel] [PATCH 1/5] avcodec/vvc_mc: split the SAD dsp prototype into one function per blocksize width
James Almer
jamrial at gmail.com
Thu May 23 15:27:12 EEST 2024
Signed-off-by: James Almer <jamrial at gmail.com>
---
libavcodec/vvc/dsp.h | 2 +-
libavcodec/vvc/inter.c | 6 ++++--
libavcodec/vvc/inter_template.c | 6 +++++-
libavcodec/x86/vvc/vvc_sad.asm | 32 ++++++++++++++++++++++++++------
libavcodec/x86/vvc/vvcdsp_init.c | 22 +++++++++++++++++-----
tests/checkasm/vvc_mc.c | 3 ++-
6 files changed, 55 insertions(+), 16 deletions(-)
diff --git a/libavcodec/vvc/dsp.h b/libavcodec/vvc/dsp.h
index 1f14096c41..55c4c81f53 100644
--- a/libavcodec/vvc/dsp.h
+++ b/libavcodec/vvc/dsp.h
@@ -99,7 +99,7 @@ typedef struct VVCInterDSPContext {
void (*apply_bdof)(uint8_t *dst, ptrdiff_t dst_stride, int16_t *src0, int16_t *src1, int block_w, int block_h);
- int (*sad)(const int16_t *src0, const int16_t *src1, int dx, int dy, int block_w, int block_h);
+ int (*sad[5])(const int16_t *src0, const int16_t *src1, int dx, int dy, int block_w, int block_h);
void (*dmvr[2][2])(int16_t *dst, const uint8_t *src, ptrdiff_t src_stride, int height,
intptr_t mx, intptr_t my, int width);
} VVCInterDSPContext;
diff --git a/libavcodec/vvc/inter.c b/libavcodec/vvc/inter.c
index e1011b4fa1..0214e46634 100644
--- a/libavcodec/vvc/inter.c
+++ b/libavcodec/vvc/inter.c
@@ -740,6 +740,8 @@ static void dmvr_mv_refine(VVCLocalContext *lc, MvField *mvf, MvField *orig_mv,
const AVFrame *ref0, const AVFrame *ref1, const int x_off, const int y_off, const int block_w, const int block_h)
{
const VVCFrameContext *fc = lc->fc;
+ static const uint8_t sad_tab[16] = { 0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4 };
+ const int tab = sad_tab[(FFALIGN(block_w, 8) >> 3) - 1];
const int sr_range = 2;
const AVFrame *ref[] = { ref0, ref1 };
int16_t *tmp[] = { lc->tmp, lc->tmp1 };
@@ -763,7 +765,7 @@ static void dmvr_mv_refine(VVCLocalContext *lc, MvField *mvf, MvField *orig_mv,
fc->vvcdsp.inter.dmvr[!!my][!!mx](tmp[i], src, src_stride, pred_h, mx, my, pred_w);
}
- min_sad = fc->vvcdsp.inter.sad(tmp[L0], tmp[L1], dx, dy, block_w, block_h);
+ min_sad = fc->vvcdsp.inter.sad[tab](tmp[L0], tmp[L1], dx, dy, block_w, block_h);
min_sad -= min_sad >> 2;
sad[dy][dx] = min_sad;
@@ -773,7 +775,7 @@ static void dmvr_mv_refine(VVCLocalContext *lc, MvField *mvf, MvField *orig_mv,
for (dy = 0; dy < SAD_ARRAY_SIZE; dy++) {
for (dx = 0; dx < SAD_ARRAY_SIZE; dx++) {
if (dx != sr_range || dy != sr_range) {
- sad[dy][dx] = fc->vvcdsp.inter.sad(lc->tmp, lc->tmp1, dx, dy, block_w, block_h);
+ sad[dy][dx] = fc->vvcdsp.inter.sad[tab](lc->tmp, lc->tmp1, dx, dy, block_w, block_h);
if (sad[dy][dx] < min_sad) {
min_sad = sad[dy][dx];
min_dx = dx;
diff --git a/libavcodec/vvc/inter_template.c b/libavcodec/vvc/inter_template.c
index a8068f4ba8..34485321d3 100644
--- a/libavcodec/vvc/inter_template.c
+++ b/libavcodec/vvc/inter_template.c
@@ -626,7 +626,11 @@ static void FUNC(ff_vvc_inter_dsp_init)(VVCInterDSPContext *const inter)
inter->apply_prof_uni_w = FUNC(apply_prof_uni_w);
inter->apply_bdof = FUNC(apply_bdof);
inter->prof_grad_filter = FUNC(prof_grad_filter);
- inter->sad = vvc_sad;
+ inter->sad[0] =
+ inter->sad[1] =
+ inter->sad[2] =
+ inter->sad[3] =
+ inter->sad[4] = vvc_sad;
}
#undef FUNCS
diff --git a/libavcodec/x86/vvc/vvc_sad.asm b/libavcodec/x86/vvc/vvc_sad.asm
index b468d89ac2..a20818530f 100644
--- a/libavcodec/x86/vvc/vvc_sad.asm
+++ b/libavcodec/x86/vvc/vvc_sad.asm
@@ -51,7 +51,7 @@ SECTION .text
INIT_YMM avx2
-cglobal vvc_sad, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_idx
+cglobal vvc_sad_8, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_idx
movsxdifnidn dxq, dxd
movsxdifnidn dyq, dyd
@@ -76,10 +76,6 @@ cglobal vvc_sad, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_
pxor m3, m3
vpbroadcastd m4, [pw_1]
- cmp block_wd, 16
- jge vvc_sad_16_128
-
- vvc_sad_8:
.loop_height:
movu xm0, [src1q]
vinserti128 m0, m0, [src1q + MAX_PB_SIZE * ROWS * 2], 1
@@ -100,7 +96,31 @@ cglobal vvc_sad, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_
movd eax, xm0
RET
- vvc_sad_16_128:
+cglobal vvc_sad_16, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_idx
+ movsxdifnidn dxq, dxd
+ movsxdifnidn dyq, dyd
+
+ sub dxq, 2
+ sub dyq, 2
+
+ mov off1q, 2
+ mov off2q, 2
+
+ add off1q, dyq
+ sub off2q, dyq
+
+ shl off1q, 7
+ shl off2q, 7
+
+ add off1q, dxq
+ sub off2q, dxq
+
+ lea src1q, [src1q + off1q * 2 + 2 * 2]
+ lea src2q, [src2q + off2q * 2 + 2 * 2]
+
+ pxor m3, m3
+ vpbroadcastd m4, [pw_1]
+
sar block_wd, 4
.loop_height:
mov off1q, src1q
diff --git a/libavcodec/x86/vvc/vvcdsp_init.c b/libavcodec/x86/vvc/vvcdsp_init.c
index 4b4a2aa937..bd60963432 100644
--- a/libavcodec/x86/vvc/vvcdsp_init.c
+++ b/libavcodec/x86/vvc/vvcdsp_init.c
@@ -312,8 +312,20 @@ ALF_FUNCS(16, 12, avx2)
c->alf.classify = ff_vvc_alf_classify_##bd##_avx2; \
} while (0)
-int ff_vvc_sad_avx2(const int16_t *src0, const int16_t *src1, int dx, int dy, int block_w, int block_h);
-#define SAD_INIT() c->inter.sad = ff_vvc_sad_avx2
+#define SAD_PROTOTYPE(w, opt) \
+int bf(ff_vvc_sad, w, opt)(const int16_t *src0, const int16_t *src1, \
+ int dx, int dy, int block_w, int block_h) \
+
+SAD_PROTOTYPE(8, avx2);
+SAD_PROTOTYPE(16, avx2);
+
+#define SAD_INIT(opt) do { \
+ c->inter.sad[0] = ff_vvc_sad_8_##opt; \
+ c->inter.sad[1] = \
+ c->inter.sad[2] = \
+ c->inter.sad[3] = \
+ c->inter.sad[4] = ff_vvc_sad_16_##opt; \
+} while (0)
#endif
void ff_vvc_dsp_init_x86(VVCDSPContext *const c, const int bd)
@@ -330,7 +342,7 @@ void ff_vvc_dsp_init_x86(VVCDSPContext *const c, const int bd)
ALF_INIT(8);
AVG_INIT(8, avx2);
MC_LINKS_AVX2(8);
- SAD_INIT();
+ SAD_INIT(avx2);
}
break;
case 10:
@@ -342,7 +354,7 @@ void ff_vvc_dsp_init_x86(VVCDSPContext *const c, const int bd)
AVG_INIT(10, avx2);
MC_LINKS_AVX2(10);
MC_LINKS_16BPC_AVX2(10);
- SAD_INIT();
+ SAD_INIT(avx2);
}
break;
case 12:
@@ -354,7 +366,7 @@ void ff_vvc_dsp_init_x86(VVCDSPContext *const c, const int bd)
AVG_INIT(12, avx2);
MC_LINKS_AVX2(12);
MC_LINKS_16BPC_AVX2(12);
- SAD_INIT();
+ SAD_INIT(avx2);
}
break;
default:
diff --git a/tests/checkasm/vvc_mc.c b/tests/checkasm/vvc_mc.c
index 1e889e2cff..deae1014d2 100644
--- a/tests/checkasm/vvc_mc.c
+++ b/tests/checkasm/vvc_mc.c
@@ -327,6 +327,7 @@ static void check_avg(void)
static void check_vvc_sad(void)
{
const int bit_depth = 10;
+ static const uint8_t sad_tab[16] = { 0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4 };
VVCDSPContext c;
LOCAL_ALIGNED_32(uint16_t, src0, [MAX_CTU_SIZE * MAX_CTU_SIZE * 4]);
LOCAL_ALIGNED_32(uint16_t, src1, [MAX_CTU_SIZE * MAX_CTU_SIZE * 4]);
@@ -341,7 +342,7 @@ static void check_vvc_sad(void)
for (int w = 8; w <= MAX_CTU_SIZE; w *= 2) {
for(int offy = 0; offy <= 4; offy++) {
for(int offx = 0; offx <= 4; offx++) {
- if(check_func(c.inter.sad, "sad_%dx%d", w, h)) {
+ if(check_func(c.inter.sad[sad_tab[(w >> 3) - 1]], "sad_%dx%d", w, h)) {
int result0;
int result1;
--
2.45.1
More information about the ffmpeg-devel mailing list