[FFmpeg-devel] [PATCH 13/18] ffv1/vulkan: redo context count tracking and quant_table_idx management

Lynne dev at lynne.ee
Sat Apr 12 10:22:44 EEST 2025


This commit also makes it possible for the encoder to choose a different
quantization table on a per-slice basis, and adds the same capability
to the decoder.

Additionally, this commit fully fixes decoding of files encoded with context=1.
---
 libavcodec/ffv1_vulkan.h              |  2 +-
 libavcodec/ffv1enc_vulkan.c           |  6 ++++--
 libavcodec/vulkan/ffv1_common.comp    |  3 +--
 libavcodec/vulkan/ffv1_dec.comp       | 17 +++++++++--------
 libavcodec/vulkan/ffv1_dec_setup.comp |  1 -
 libavcodec/vulkan/ffv1_enc_setup.comp |  3 ++-
 libavcodec/vulkan/ffv1_reset.comp     | 11 ++++++-----
 libavcodec/vulkan_ffv1.c              | 22 ++++++++--------------
 8 files changed, 31 insertions(+), 34 deletions(-)

diff --git a/libavcodec/ffv1_vulkan.h b/libavcodec/ffv1_vulkan.h
index 1e0e6dd228..372478f4b7 100644
--- a/libavcodec/ffv1_vulkan.h
+++ b/libavcodec/ffv1_vulkan.h
@@ -49,9 +49,9 @@ typedef struct FFv1VkRCTParameters {
 } FFv1VkRCTParameters;
 
 typedef struct FFv1VkResetParameters {
+    uint32_t context_count[MAX_QUANT_TABLES];
     VkDeviceAddress slice_state;
     uint32_t plane_state_size;
-    uint32_t context_count;
     uint8_t codec_planes;
     uint8_t key_frame;
     uint8_t version;
diff --git a/libavcodec/ffv1enc_vulkan.c b/libavcodec/ffv1enc_vulkan.c
index 5409927589..688c14fb81 100644
--- a/libavcodec/ffv1enc_vulkan.c
+++ b/libavcodec/ffv1enc_vulkan.c
@@ -542,10 +542,12 @@ static int vulkan_encode_ffv1_submit_frame(AVCodecContext *avctx,
         pd_reset = (FFv1VkResetParameters) {
             .slice_state = slice_data_buf->address + f->slice_count*256,
             .plane_state_size = plane_state_size,
-            .context_count = context_count,
             .codec_planes = f->plane_count,
             .key_frame = f->key_frame,
         };
+        for (int i = 0; i < f->quant_table_count; i++)
+            pd_reset.context_count[i] = f->context_count[i];
+
         ff_vk_shader_update_push_const(&fv->s, exec, &fv->reset,
                                        VK_SHADER_STAGE_COMPUTE_BIT,
                                        0, sizeof(pd_reset), &pd_reset);
@@ -1071,9 +1073,9 @@ static int init_reset_shader(AVCodecContext *avctx, FFVkSPIRVCompiler *spv)
     GLSLD(ff_source_common_comp);
 
     GLSLC(0, layout(push_constant, scalar) uniform pushConstants {             );
+    GLSLF(1,    uint context_count[%i];                                        ,MAX_QUANT_TABLES);
     GLSLC(1,    u8buf slice_state;                                             );
     GLSLC(1,    uint plane_state_size;                                         );
-    GLSLC(1,    uint context_count;                                            );
     GLSLC(1,    uint8_t codec_planes;                                          );
     GLSLC(1,    uint8_t key_frame;                                             );
     GLSLC(1,    uint8_t version;                                               );
diff --git a/libavcodec/vulkan/ffv1_common.comp b/libavcodec/vulkan/ffv1_common.comp
index d2bd7e736e..64c1c2ce80 100644
--- a/libavcodec/vulkan/ffv1_common.comp
+++ b/libavcodec/vulkan/ffv1_common.comp
@@ -32,8 +32,7 @@ struct SliceContext {
     ivec2 slice_dim;
     ivec2 slice_pos;
     ivec2 slice_rct_coef;
-    u8vec4 quant_table_idx;
-    uint context_count;
+    u8vec3 quant_table_idx;
 
     uint hdr_len; // only used for golomb
 
diff --git a/libavcodec/vulkan/ffv1_dec.comp b/libavcodec/vulkan/ffv1_dec.comp
index ae0324cb26..a6272d4832 100644
--- a/libavcodec/vulkan/ffv1_dec.comp
+++ b/libavcodec/vulkan/ffv1_dec.comp
@@ -51,8 +51,8 @@ ivec2 get_pred(ivec2 sp, ivec2 off, int p, int sw, uint8_t quant_table_idx)
         (quant_table[quant_table_idx][4][127] != 0)) {
         TYPE cur2 = TYPE(0);
         if (off.x > 0) {
-            const ivec2 yoff_border2 = off.x == 1 ? ivec2(1, -1) : ivec2(0, 0);
-            cur2 = TYPE(imageLoad(dec[p], sp + LADDR(off + ivec2(-2,  0) + yoff_border2))[0]);
+            const ivec2 yoff_border2 = off.x == 1 ? ivec2(-1, -1) : ivec2(-2, 0);
+            cur2 = TYPE(imageLoad(dec[p], sp + LADDR(off + yoff_border2))[0]);
         }
         base += quant_table[quant_table_idx][3][(cur2 - cur) & MAX_QUANT_TABLE_MASK];
 
@@ -156,7 +156,7 @@ void decode_line_pcm(inout SliceContext sc, ivec2 sp, int w, int y, int p, int b
 
 void decode_line(inout SliceContext sc, ivec2 sp, int w,
                  int y, int p, int bits, uint64_t state,
-                 const int run_index)
+                 uint8_t quant_table_idx, const int run_index)
 {
 #ifndef RGB
     if (p > 0 && p < 3) {
@@ -167,7 +167,7 @@ void decode_line(inout SliceContext sc, ivec2 sp, int w,
 
     for (int x = 0; x < w; x++) {
         ivec2 pr = get_pred(sp, ivec2(x, y), p, w,
-                            sc.quant_table_idx[p]);
+                            quant_table_idx);
 
         int diff = get_isymbol(sc.c, state + CONTEXT_SIZE*abs(pr[0]));
         if (pr[0] < 0)
@@ -182,7 +182,7 @@ void decode_line(inout SliceContext sc, ivec2 sp, int w,
 
 void decode_line(inout SliceContext sc, ivec2 sp, int w,
                  int y, int p, int bits, uint64_t state,
-                 inout int run_index)
+                 uint8_t quant_table_idx, inout int run_index)
 {
 #ifndef RGB
     if (p > 0 && p < 3) {
@@ -198,7 +198,7 @@ void decode_line(inout SliceContext sc, ivec2 sp, int w,
         ivec2 pos = sp + ivec2(x, y);
         int diff;
         ivec2 pr = get_pred(sp, ivec2(x, y), p, w,
-                            sc.quant_table_idx[p]);
+                            quant_table_idx);
 
         VlcState sb = VlcState(state + VLC_STATE_SIZE*abs(pr[0]));
 
@@ -325,6 +325,7 @@ void decode_slice(inout SliceContext sc, const uint slice_idx)
     /* Arithmetic coding */
 #endif
     {
+        u8vec4 quant_table_idx = sc.quant_table_idx.xyyz;
         u64vec4 slice_state_off = (uint64_t(slice_state) +
                                    slice_idx*plane_state_size*codec_planes) +
                                   plane_state_size*uvec4(0, 1, 1, 2);
@@ -337,13 +338,13 @@ void decode_slice(inout SliceContext sc, const uint slice_idx)
 
             for (int y = 0; y < h; y++)
                 decode_line(sc, sp, w, y, p, bits,
-                            slice_state_off[p], run_index);
+                            slice_state_off[p], quant_table_idx[p], run_index);
         }
 #else
         for (int y = 0; y < sc.slice_dim.y; y++) {
             for (int p = 0; p < color_planes; p++)
                 decode_line(sc, sp, w, y, p, bits,
-                            slice_state_off[p], run_index);
+                            slice_state_off[p], quant_table_idx[p], run_index);
 
             writeout_rgb(sc, sp, w, y, true);
         }
diff --git a/libavcodec/vulkan/ffv1_dec_setup.comp b/libavcodec/vulkan/ffv1_dec_setup.comp
index a10163a8d6..5da63be56d 100644
--- a/libavcodec/vulkan/ffv1_dec_setup.comp
+++ b/libavcodec/vulkan/ffv1_dec_setup.comp
@@ -76,7 +76,6 @@ bool decode_slice_header(inout SliceContext sc, uint64_t state)
         if (idx >= quant_table_count)
             return true;
         sc.quant_table_idx[i] = uint8_t(idx);
-        sc.context_count = context_count[idx];
     }
 
     get_usymbol(sc.c, state);
diff --git a/libavcodec/vulkan/ffv1_enc_setup.comp b/libavcodec/vulkan/ffv1_enc_setup.comp
index 23f09b2af6..44c13404d8 100644
--- a/libavcodec/vulkan/ffv1_enc_setup.comp
+++ b/libavcodec/vulkan/ffv1_enc_setup.comp
@@ -38,6 +38,7 @@ void init_slice(out SliceContext sc, const uint slice_idx)
     sc.slice_rct_coef = ivec2(1, 1);
     sc.slice_coding_mode = int(force_pcm == 1);
     sc.slice_reset_contexts = sc.slice_coding_mode == 1;
+    sc.quant_table_idx = u8vec3(context_model);
 
     rac_init(sc.c,
              OFFBUF(u8buf, out_data, slice_idx * slice_size_max),
@@ -84,7 +85,7 @@ void write_slice_header(inout SliceContext sc, uint64_t state)
     put_symbol_unsigned(sc.c, state, 0);
 
     for (int i = 0; i < codec_planes; i++)
-        put_symbol_unsigned(sc.c, state, context_model);
+        put_symbol_unsigned(sc.c, state, sc.quant_table_idx[i]);
 
     put_symbol_unsigned(sc.c, state, pic_mode);
     put_symbol_unsigned(sc.c, state, sar.x);
diff --git a/libavcodec/vulkan/ffv1_reset.comp b/libavcodec/vulkan/ffv1_reset.comp
index 1b87ca754e..cfb7dcc444 100644
--- a/libavcodec/vulkan/ffv1_reset.comp
+++ b/libavcodec/vulkan/ffv1_reset.comp
@@ -28,14 +28,15 @@ void main(void)
         slice_ctx[slice_idx].slice_reset_contexts == false)
         return;
 
+    const uint8_t qidx = slice_ctx[slice_idx].quant_table_idx[gl_WorkGroupID.z];
+    uint contexts = context_count[qidx];
     uint64_t slice_state_off = uint64_t(slice_state) +
                                slice_idx*plane_state_size*codec_planes;
 
 #ifdef GOLOMB
     uint64_t start = slice_state_off +
-                     (gl_WorkGroupID.z*context_count +
-                      gl_LocalInvocationID.x)*VLC_STATE_SIZE;
-    for (uint x = gl_LocalInvocationID.x; x < context_count; x += gl_WorkGroupSize.x) {
+                     (gl_WorkGroupID.z*(plane_state_size/VLC_STATE_SIZE) + gl_LocalInvocationID.x)*VLC_STATE_SIZE;
+    for (uint x = gl_LocalInvocationID.x; x < contexts; x += gl_WorkGroupSize.x) {
         VlcState sb = VlcState(start);
         sb.drift     =  int16_t(0);
         sb.error_sum = uint16_t(4);
@@ -45,9 +46,9 @@ void main(void)
     }
 #else
     uint64_t start = slice_state_off +
-                     (gl_WorkGroupID.z*context_count)*CONTEXT_SIZE +
+                     gl_WorkGroupID.z*plane_state_size +
                      (gl_LocalInvocationID.x << 2 /* dwords */); /* Bytes */
-    uint count_total = context_count*(CONTEXT_SIZE /* bytes */ >> 2 /* dwords */);
+    uint count_total = contexts*(CONTEXT_SIZE /* bytes */ >> 2 /* dwords */);
     for (uint x = gl_LocalInvocationID.x; x < count_total; x += gl_WorkGroupSize.x) {
         u32buf(start).v = 0x80808080;
         start += gl_WorkGroupSize.x*(CONTEXT_SIZE >> 3 /* 1/8th of context */);
diff --git a/libavcodec/vulkan_ffv1.c b/libavcodec/vulkan_ffv1.c
index 5584b72385..aaebcd53b5 100644
--- a/libavcodec/vulkan_ffv1.c
+++ b/libavcodec/vulkan_ffv1.c
@@ -49,7 +49,6 @@ typedef struct FFv1VulkanDecodePicture {
     uint32_t plane_state_size;
     uint32_t slice_state_size;
     uint32_t slice_data_size;
-    uint32_t max_context_count;
 
     AVBufferRef *slice_offset_buf;
     uint32_t    *slice_offset;
@@ -77,8 +76,6 @@ typedef struct FFv1VulkanDecodeContext {
 } FFv1VulkanDecodeContext;
 
 typedef struct FFv1VkParameters {
-    uint32_t context_count[MAX_QUANT_TABLES];
-
     VkDeviceAddress slice_data;
     VkDeviceAddress slice_state;
     VkDeviceAddress scratch_data;
@@ -111,8 +108,6 @@ typedef struct FFv1VkParameters {
 static void add_push_data(FFVulkanShader *shd)
 {
     GLSLC(0, layout(push_constant, scalar) uniform pushConstants {  );
-    GLSLF(1,    uint context_count[%i];                             ,MAX_QUANT_TABLES);
-    GLSLC(0,                                                        );
     GLSLC(1,    u8buf slice_data;                                   );
     GLSLC(1,    u8buf slice_state;                                  );
     GLSLC(1,    u8buf scratch_data;                                 );
@@ -162,13 +157,15 @@ static int vk_ffv1_start_frame(AVCodecContext          *avctx,
     AVHWFramesContext *hwfc = (AVHWFramesContext *)avctx->hw_frames_ctx->data;
     enum AVPixelFormat sw_format = hwfc->sw_format;
 
+    int max_contexts;
     int is_rgb = !(f->colorspace == 0 && sw_format != AV_PIX_FMT_YA8) &&
                  !(sw_format == AV_PIX_FMT_YA8);
 
     fp->slice_num = 0;
 
+    max_contexts = 0;
     for (int i = 0; i < f->quant_table_count; i++)
-        fp->max_context_count = FFMAX(f->context_count[i], fp->max_context_count);
+        max_contexts = FFMAX(f->context_count[i], max_contexts);
 
     /* Allocate slice buffer data */
     if (f->ac == AC_GOLOMB_RICE)
@@ -176,7 +173,7 @@ static int vk_ffv1_start_frame(AVCodecContext          *avctx,
     else
         fp->plane_state_size = CONTEXT_SIZE;
 
-    fp->plane_state_size *= fp->max_context_count;
+    fp->plane_state_size *= max_contexts;
     fp->slice_state_size = fp->plane_state_size*f->plane_count;
 
     fp->slice_data_size = 256; /* Overestimation for the SliceContext struct */
@@ -430,8 +427,6 @@ static int vk_ffv1_end_frame(AVCodecContext *avctx)
 
     ff_vk_exec_bind_shader(&ctx->s, exec, &fv->setup);
     pd = (FFv1VkParameters) {
-        /* context_count */
-
         .slice_data = slices_buf->address,
         .slice_state = slice_state->address + f->slice_count*fp->slice_data_size,
         .scratch_data = tmp_data->address,
@@ -471,9 +466,6 @@ static int vk_ffv1_end_frame(AVCodecContext *avctx)
     else
         ff_vk_set_perm(sw_format, pd.fmt_lut, 0);
 
-    for (int i = 0; i < MAX_QUANT_TABLES; i++)
-        pd.context_count[i] = f->context_count[i];
-
     ff_vk_shader_update_push_const(&ctx->s, exec, &fv->setup,
                                    VK_SHADER_STAGE_COMPUTE_BIT,
                                    0, sizeof(pd), &pd);
@@ -505,12 +497,14 @@ static int vk_ffv1_end_frame(AVCodecContext *avctx)
     pd_reset = (FFv1VkResetParameters) {
         .slice_state = slice_state->address + f->slice_count*fp->slice_data_size,
         .plane_state_size = fp->plane_state_size,
-        .context_count = fp->max_context_count,
         .codec_planes = f->plane_count,
         .key_frame = f->picture.f->flags & AV_FRAME_FLAG_KEY,
         .version = f->version,
         .micro_version = f->micro_version,
     };
+    for (int i = 0; i < f->quant_table_count; i++)
+        pd_reset.context_count[i] = f->context_count[i];
+
     ff_vk_shader_update_push_const(&ctx->s, exec, reset_shader,
                                    VK_SHADER_STAGE_COMPUTE_BIT,
                                    0, sizeof(pd_reset), &pd_reset);
@@ -763,9 +757,9 @@ static int init_reset_shader(FFV1Context *f, FFVulkanContext *s,
     GLSLD(ff_source_common_comp);
 
     GLSLC(0, layout(push_constant, scalar) uniform pushConstants {             );
+    GLSLF(1,    uint context_count[%i];                                        ,MAX_QUANT_TABLES);
     GLSLC(1,    u8buf slice_state;                                             );
     GLSLC(1,    uint plane_state_size;                                         );
-    GLSLC(1,    uint context_count;                                            );
     GLSLC(1,    uint8_t codec_planes;                                          );
     GLSLC(1,    uint8_t key_frame;                                             );
     GLSLC(1,    uint8_t version;                                               );
-- 
2.47.2


More information about the ffmpeg-devel mailing list